Diffstat (limited to 'test/support')
-rw-r--r--  test/support/integration/plugins/cache/jsonfile.py  63
-rw-r--r--  test/support/integration/plugins/filter/json_query.py  53
-rw-r--r--  test/support/integration/plugins/inventory/aws_ec2.py  760
-rw-r--r--  test/support/integration/plugins/inventory/docker_swarm.py  351
-rw-r--r--  test/support/integration/plugins/inventory/foreman.py  295
-rw-r--r--  test/support/integration/plugins/lookup/rabbitmq.py  190
-rw-r--r--  test/support/integration/plugins/module_utils/aws/__init__.py  0
-rw-r--r--  test/support/integration/plugins/module_utils/aws/core.py  335
-rw-r--r--  test/support/integration/plugins/module_utils/aws/iam.py  49
-rw-r--r--  test/support/integration/plugins/module_utils/aws/s3.py  50
-rw-r--r--  test/support/integration/plugins/module_utils/aws/waiters.py  405
-rw-r--r--  test/support/integration/plugins/module_utils/azure_rm_common.py  1473
-rw-r--r--  test/support/integration/plugins/module_utils/azure_rm_common_rest.py  97
-rw-r--r--  test/support/integration/plugins/module_utils/cloud.py  217
-rw-r--r--  test/support/integration/plugins/module_utils/compat/__init__.py  0
-rw-r--r--  test/support/integration/plugins/module_utils/compat/ipaddress.py  2476
-rw-r--r--  test/support/integration/plugins/module_utils/crypto.py  2125
-rw-r--r--  test/support/integration/plugins/module_utils/database.py  142
-rw-r--r--  test/support/integration/plugins/module_utils/docker/__init__.py  0
-rw-r--r--  test/support/integration/plugins/module_utils/docker/common.py  1022
-rw-r--r--  test/support/integration/plugins/module_utils/docker/swarm.py  280
-rw-r--r--  test/support/integration/plugins/module_utils/ec2.py  758
-rw-r--r--  test/support/integration/plugins/module_utils/ecs/__init__.py  0
-rw-r--r--  test/support/integration/plugins/module_utils/ecs/api.py  364
-rw-r--r--  test/support/integration/plugins/module_utils/mysql.py  106
-rw-r--r--  test/support/integration/plugins/module_utils/net_tools/__init__.py  0
-rw-r--r--  test/support/integration/plugins/module_utils/network/__init__.py  0
-rw-r--r--  test/support/integration/plugins/module_utils/network/common/__init__.py  0
-rw-r--r--  test/support/integration/plugins/module_utils/network/common/utils.py  643
-rw-r--r--  test/support/integration/plugins/module_utils/postgres.py  330
-rw-r--r--  test/support/integration/plugins/module_utils/rabbitmq.py  220
l---------  test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py  1
l---------  test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py  1
l---------  test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py  1
l---------  test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py  1
l---------  test/support/integration/plugins/modules/_azure_rm_resource_facts.py  1
l---------  test/support/integration/plugins/modules/_azure_rm_webapp_facts.py  1
-rw-r--r--  test/support/integration/plugins/modules/aws_az_info.py  111
-rw-r--r--  test/support/integration/plugins/modules/aws_s3.py  925
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_appserviceplan.py  379
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_functionapp.py  421
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_functionapp_info.py  207
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py  241
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py  217
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py  304
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py  212
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py  277
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py  208
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbserver.py  388
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py  265
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_resource.py  427
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_resource_info.py  432
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_storageaccount.py  684
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_webapp.py  1070
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_webapp_info.py  489
-rw-r--r--  test/support/integration/plugins/modules/azure_rm_webappslot.py  1058
-rw-r--r--  test/support/integration/plugins/modules/cloud_init_data_facts.py  134
-rw-r--r--  test/support/integration/plugins/modules/cloudformation.py  837
-rw-r--r--  test/support/integration/plugins/modules/cloudformation_info.py  355
-rw-r--r--  test/support/integration/plugins/modules/deploy_helper.py  521
-rw-r--r--  test/support/integration/plugins/modules/docker_swarm.py  681
-rw-r--r--  test/support/integration/plugins/modules/ec2.py  1766
-rw-r--r--  test/support/integration/plugins/modules/ec2_ami_info.py  282
-rw-r--r--  test/support/integration/plugins/modules/ec2_group.py  1345
-rw-r--r--  test/support/integration/plugins/modules/ec2_vpc_net.py  524
-rw-r--r--  test/support/integration/plugins/modules/ec2_vpc_subnet.py  604
-rw-r--r--  test/support/integration/plugins/modules/flatpak_remote.py  243
-rw-r--r--  test/support/integration/plugins/modules/htpasswd.py  275
-rw-r--r--  test/support/integration/plugins/modules/locale_gen.py  237
-rw-r--r--  test/support/integration/plugins/modules/lvg.py  295
-rw-r--r--  test/support/integration/plugins/modules/mongodb_parameter.py  223
-rw-r--r--  test/support/integration/plugins/modules/mongodb_user.py  474
-rw-r--r--  test/support/integration/plugins/modules/pids.py  89
-rw-r--r--  test/support/integration/plugins/modules/pkgng.py  406
-rw-r--r--  test/support/integration/plugins/modules/postgresql_db.py  657
-rw-r--r--  test/support/integration/plugins/modules/postgresql_privs.py  1097
-rw-r--r--  test/support/integration/plugins/modules/postgresql_query.py  364
-rw-r--r--  test/support/integration/plugins/modules/postgresql_set.py  434
-rw-r--r--  test/support/integration/plugins/modules/postgresql_table.py  601
-rw-r--r--  test/support/integration/plugins/modules/postgresql_user.py  927
-rw-r--r--  test/support/integration/plugins/modules/rabbitmq_plugin.py  180
-rw-r--r--  test/support/integration/plugins/modules/rabbitmq_queue.py  257
-rw-r--r--  test/support/integration/plugins/modules/s3_bucket.py  740
-rw-r--r--  test/support/integration/plugins/modules/sefcontext.py  298
-rw-r--r--  test/support/integration/plugins/modules/selogin.py  260
-rw-r--r--  test/support/integration/plugins/modules/synchronize.py  618
-rw-r--r--  test/support/integration/plugins/modules/timezone.py  909
-rw-r--r--  test/support/integration/plugins/modules/x509_crl.py  783
-rw-r--r--  test/support/integration/plugins/modules/x509_crl_info.py  281
-rw-r--r--  test/support/integration/plugins/modules/xml.py  966
-rw-r--r--  test/support/integration/plugins/modules/zypper.py  540
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py  40
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py  90
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py  199
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py  235
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py  209
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py  42
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py  324
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py  404
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py  924
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py  97
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py  66
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py  14
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py  1186
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py  531
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py  91
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py  2578
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py  27
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py  473
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py  162
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py  179
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py  275
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py  316
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py  686
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py  147
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py  61
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py  444
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py  71
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py  82
-rw-r--r--  test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py  70
-rw-r--r--  test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py  133
-rw-r--r--  test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py  465
-rw-r--r--  test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py  81
-rw-r--r--  test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py  197
-rw-r--r--  test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py  229
-rw-r--r--  test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py  596
-rw-r--r--  test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py  115
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py  129
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py  342
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py  63
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py  22
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py  263
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py  69
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py  81
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py  80
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py  56
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py  89
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py  99
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py  438
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py  83
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py  380
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py  134
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py  143
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py  152
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py  162
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py  116
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py  155
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py  181
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py  231
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py  124
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py  223
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py  354
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py  174
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py  513
-rw-r--r--  test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py  53
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1  1
l---------  test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py  1
-rw-r--r--  test/support/windows-integration/plugins/action/win_copy.py  522
-rw-r--r--  test/support/windows-integration/plugins/action/win_reboot.py  96
-rw-r--r--  test/support/windows-integration/plugins/action/win_template.py  29
-rw-r--r--  test/support/windows-integration/plugins/become/runas.py  70
-rw-r--r--  test/support/windows-integration/plugins/module_utils/Ansible.Service.cs  1341
-rw-r--r--  test/support/windows-integration/plugins/modules/async_status.ps1  58
-rw-r--r--  test/support/windows-integration/plugins/modules/setup.ps1  516
-rw-r--r--  test/support/windows-integration/plugins/modules/slurp.ps1  28
-rw-r--r--  test/support/windows-integration/plugins/modules/win_acl.ps1  225
-rw-r--r--  test/support/windows-integration/plugins/modules/win_acl.py  132
-rw-r--r--  test/support/windows-integration/plugins/modules/win_certificate_store.ps1  260
-rw-r--r--  test/support/windows-integration/plugins/modules/win_certificate_store.py  208
-rw-r--r--  test/support/windows-integration/plugins/modules/win_command.ps1  78
-rw-r--r--  test/support/windows-integration/plugins/modules/win_command.py  136
-rw-r--r--  test/support/windows-integration/plugins/modules/win_copy.ps1  403
-rw-r--r--  test/support/windows-integration/plugins/modules/win_copy.py  207
-rw-r--r--  test/support/windows-integration/plugins/modules/win_data_deduplication.ps1  129
-rw-r--r--  test/support/windows-integration/plugins/modules/win_data_deduplication.py  87
-rw-r--r--  test/support/windows-integration/plugins/modules/win_dsc.ps1  398
-rw-r--r--  test/support/windows-integration/plugins/modules/win_dsc.py  183
-rw-r--r--  test/support/windows-integration/plugins/modules/win_feature.ps1  111
-rw-r--r--  test/support/windows-integration/plugins/modules/win_feature.py  149
-rw-r--r--  test/support/windows-integration/plugins/modules/win_file.ps1  152
-rw-r--r--  test/support/windows-integration/plugins/modules/win_file.py  70
-rw-r--r--  test/support/windows-integration/plugins/modules/win_find.ps1  416
-rw-r--r--  test/support/windows-integration/plugins/modules/win_find.py  345
-rw-r--r--  test/support/windows-integration/plugins/modules/win_format.ps1  200
-rw-r--r--  test/support/windows-integration/plugins/modules/win_format.py  103
-rw-r--r--  test/support/windows-integration/plugins/modules/win_get_url.ps1  274
-rw-r--r--  test/support/windows-integration/plugins/modules/win_get_url.py  215
-rw-r--r--  test/support/windows-integration/plugins/modules/win_lineinfile.ps1  450
-rw-r--r--  test/support/windows-integration/plugins/modules/win_lineinfile.py  180
-rw-r--r--  test/support/windows-integration/plugins/modules/win_path.ps1  145
-rw-r--r--  test/support/windows-integration/plugins/modules/win_path.py  79
-rw-r--r--  test/support/windows-integration/plugins/modules/win_ping.ps1  21
-rw-r--r--  test/support/windows-integration/plugins/modules/win_ping.py  55
-rw-r--r--  test/support/windows-integration/plugins/modules/win_psexec.ps1  152
-rw-r--r--  test/support/windows-integration/plugins/modules/win_psexec.py  172
-rw-r--r--  test/support/windows-integration/plugins/modules/win_reboot.py  131
-rw-r--r--  test/support/windows-integration/plugins/modules/win_regedit.ps1  495
-rw-r--r--  test/support/windows-integration/plugins/modules/win_regedit.py  210
-rw-r--r--  test/support/windows-integration/plugins/modules/win_security_policy.ps1  196
-rw-r--r--  test/support/windows-integration/plugins/modules/win_security_policy.py  126
-rw-r--r--  test/support/windows-integration/plugins/modules/win_shell.ps1  138
-rw-r--r--  test/support/windows-integration/plugins/modules/win_shell.py  167
-rw-r--r--  test/support/windows-integration/plugins/modules/win_stat.ps1  186
-rw-r--r--  test/support/windows-integration/plugins/modules/win_stat.py  236
-rw-r--r--  test/support/windows-integration/plugins/modules/win_tempfile.ps1  72
-rw-r--r--  test/support/windows-integration/plugins/modules/win_tempfile.py  67
-rw-r--r--  test/support/windows-integration/plugins/modules/win_template.py  66
-rw-r--r--  test/support/windows-integration/plugins/modules/win_user.ps1  273
-rw-r--r--  test/support/windows-integration/plugins/modules/win_user.py  194
-rw-r--r--  test/support/windows-integration/plugins/modules/win_user_right.ps1  349
-rw-r--r--  test/support/windows-integration/plugins/modules/win_user_right.py  108
-rw-r--r--  test/support/windows-integration/plugins/modules/win_wait_for.ps1  259
-rw-r--r--  test/support/windows-integration/plugins/modules/win_wait_for.py  155
-rw-r--r--  test/support/windows-integration/plugins/modules/win_whoami.ps1  837
-rw-r--r--  test/support/windows-integration/plugins/modules/win_whoami.py  203
227 files changed, 69953 insertions, 0 deletions
diff --git a/test/support/integration/plugins/cache/jsonfile.py b/test/support/integration/plugins/cache/jsonfile.py
new file mode 100644
index 00000000..80b16f55
--- /dev/null
+++ b/test/support/integration/plugins/cache/jsonfile.py
@@ -0,0 +1,63 @@
+# (c) 2014, Brian Coca, Josh Drake, et al
+# (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+DOCUMENTATION = '''
+ cache: jsonfile
+ short_description: JSON formatted files.
+ description:
+ - This cache uses JSON-formatted, per-host files saved to the filesystem.
+ version_added: "1.9"
+ author: Ansible Core (@ansible-core)
+ options:
+ _uri:
+ required: True
+ description:
+ - Path in which the cache plugin will save the JSON files
+ env:
+ - name: ANSIBLE_CACHE_PLUGIN_CONNECTION
+ ini:
+ - key: fact_caching_connection
+ section: defaults
+ _prefix:
+ description: User defined prefix to use when creating the JSON files
+ env:
+ - name: ANSIBLE_CACHE_PLUGIN_PREFIX
+ ini:
+ - key: fact_caching_prefix
+ section: defaults
+ _timeout:
+ default: 86400
+ description: Expiration timeout in seconds for the cache plugin data. Set to 0 to never expire
+ env:
+ - name: ANSIBLE_CACHE_PLUGIN_TIMEOUT
+ ini:
+ - key: fact_caching_timeout
+ section: defaults
+ type: integer
+'''
+
+import codecs
+import json
+
+from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder
+from ansible.plugins.cache import BaseFileCacheModule
+
+
+class CacheModule(BaseFileCacheModule):
+ """
+ A caching module backed by json files.
+ """
+
+ def _load(self, filepath):
+ # Valid JSON is always UTF-8 encoded.
+ with codecs.open(filepath, 'r', encoding='utf-8') as f:
+ return json.load(f, cls=AnsibleJSONDecoder)
+
+ def _dump(self, value, filepath):
+ with codecs.open(filepath, 'w', encoding='utf-8') as f:
+ f.write(json.dumps(value, cls=AnsibleJSONEncoder, sort_keys=True, indent=4))
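
For reference, a minimal sketch of the round-trip that _load() and _dump() implement, using the standard json module in place of AnsibleJSONEncoder/AnsibleJSONDecoder (which additionally handle Ansible-specific types such as vaulted strings); the facts dict and file name below are illustrative:

import codecs
import json
import os
import tempfile

facts = {'ansible_distribution': 'Fedora', 'ansible_processor_count': 4}
path = os.path.join(tempfile.gettempdir(), 'examplehost')  # one cache file per host

# _dump: key-sorted, indented JSON keeps the per-host cache files stable and readable
with codecs.open(path, 'w', encoding='utf-8') as f:
    f.write(json.dumps(facts, sort_keys=True, indent=4))

# _load: valid JSON is always UTF-8, so the encoding can be fixed
with codecs.open(path, 'r', encoding='utf-8') as f:
    assert json.load(f) == facts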
diff --git a/test/support/integration/plugins/filter/json_query.py b/test/support/integration/plugins/filter/json_query.py
new file mode 100644
index 00000000..d1da71b4
--- /dev/null
+++ b/test/support/integration/plugins/filter/json_query.py
@@ -0,0 +1,53 @@
+# (c) 2015, Filipe Niero Felisbino <filipenf@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from ansible.errors import AnsibleError, AnsibleFilterError
+
+try:
+ import jmespath
+ HAS_LIB = True
+except ImportError:
+ HAS_LIB = False
+
+
+def json_query(data, expr):
+ '''Query data using jmespath query language ( http://jmespath.org ). Example:
+ - debug: msg="{{ instance | json_query(tagged_instances[*].block_device_mapping.*.volume_id') }}"
+ '''
+ if not HAS_LIB:
+ raise AnsibleError('You need to install "jmespath" prior to running '
+ 'the json_query filter')
+
+ try:
+ return jmespath.search(expr, data)
+ except jmespath.exceptions.JMESPathError as e:
+ raise AnsibleFilterError('JMESPathError in json_query filter plugin:\n%s' % e)
+ except Exception as e:
+ # For older jmespath, we can get ValueError and TypeError without much info.
+ raise AnsibleFilterError('Error in jmespath.search in json_query filter plugin:\n%s' % e)
+
+
+class FilterModule(object):
+ ''' Query filter '''
+
+ def filters(self):
+ return {
+ 'json_query': json_query
+ }
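
For reference, the filter is a thin wrapper around jmespath.search(), so the query from the docstring example can be reproduced directly (requires the jmespath package; the sample data below is made up):

import jmespath

instance = {
    'tagged_instances': [
        {'block_device_mapping': {'/dev/sda1': {'volume_id': 'vol-0123'},
                                  '/dev/sdb1': {'volume_id': 'vol-4567'}}},
    ]
}

# Equivalent to: {{ instance | json_query('tagged_instances[*].block_device_mapping.*.volume_id') }}
print(jmespath.search('tagged_instances[*].block_device_mapping.*.volume_id', instance))
# -> [['vol-0123', 'vol-4567']]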
diff --git a/test/support/integration/plugins/inventory/aws_ec2.py b/test/support/integration/plugins/inventory/aws_ec2.py
new file mode 100644
index 00000000..09c42cf9
--- /dev/null
+++ b/test/support/integration/plugins/inventory/aws_ec2.py
@@ -0,0 +1,760 @@
+# Copyright (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+DOCUMENTATION = '''
+ name: aws_ec2
+ plugin_type: inventory
+ short_description: EC2 inventory source
+ requirements:
+ - boto3
+ - botocore
+ extends_documentation_fragment:
+ - inventory_cache
+ - constructed
+ description:
+ - Get inventory hosts from Amazon Web Services EC2.
+ - Uses a YAML configuration file that ends with C(aws_ec2.(yml|yaml)).
+ notes:
+ - If no credentials are provided and the control node has an associated IAM instance profile then the
+ role will be used for authentication.
+ author:
+ - Sloane Hertel (@s-hertel)
+ options:
+ aws_profile:
+ description: The AWS profile
+ type: str
+ aliases: [ boto_profile ]
+ env:
+ - name: AWS_DEFAULT_PROFILE
+ - name: AWS_PROFILE
+ aws_access_key:
+ description: The AWS access key to use.
+ type: str
+ aliases: [ aws_access_key_id ]
+ env:
+ - name: EC2_ACCESS_KEY
+ - name: AWS_ACCESS_KEY
+ - name: AWS_ACCESS_KEY_ID
+ aws_secret_key:
+ description: The AWS secret key that corresponds to the access key.
+ type: str
+ aliases: [ aws_secret_access_key ]
+ env:
+ - name: EC2_SECRET_KEY
+ - name: AWS_SECRET_KEY
+ - name: AWS_SECRET_ACCESS_KEY
+ aws_security_token:
+ description: The AWS security token if using temporary access and secret keys.
+ type: str
+ env:
+ - name: EC2_SECURITY_TOKEN
+ - name: AWS_SESSION_TOKEN
+ - name: AWS_SECURITY_TOKEN
+ plugin:
+ description: Token that ensures this is a source file for the plugin.
+ required: True
+ choices: ['aws_ec2']
+ iam_role_arn:
+ description: The ARN of the IAM role to assume to perform the inventory lookup. You should still provide AWS
+ credentials with enough privilege to perform the AssumeRole action.
+ version_added: '2.9'
+ regions:
+ description:
+ - A list of regions in which to describe EC2 instances.
+ - If empty (the default) this will include all regions, except possibly restricted ones like us-gov-west-1 and cn-north-1.
+ type: list
+ default: []
+ hostnames:
+ description:
+ - A list in order of precedence for hostname variables.
+ - You can use the options specified in U(http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options).
+ - To use tags as hostnames use the syntax tag:Name=Value to use the hostname Name_Value, or tag:Name to use the value of the Name tag.
+ type: list
+ default: []
+ filters:
+ description:
+ - A dictionary of filter value pairs.
+ - Available filters are listed here U(http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options).
+ type: dict
+ default: {}
+ include_extra_api_calls:
+ description:
+ - Add two additional API calls for every instance to include 'persistent' and 'events' host variables.
+ - Spot instances may be persistent and instances may have associated events.
+ type: bool
+ default: False
+ version_added: '2.8'
+ strict_permissions:
+ description:
+ - By default if a 403 (Forbidden) error code is encountered this plugin will fail.
+ - You can set this option to False in the inventory config file which will allow 403 errors to be gracefully skipped.
+ type: bool
+ default: True
+ use_contrib_script_compatible_sanitization:
+ description:
+ - By default this plugin uses a general group-name sanitization to create safe and usable group names for use in Ansible.
+ This option allows you to override that, to ease migration from the old inventory script; it
+ matches the sanitization of groups when the script's ``replace_dash_in_groups`` option is set to ``False``.
+ To replicate behavior of ``replace_dash_in_groups = True`` with constructed groups,
+ you will need to replace hyphens with underscores via the regex_replace filter for those entries.
+ - For this to work you should also turn off the TRANSFORM_INVALID_GROUP_CHARS setting,
+ otherwise the core engine will just use the standard sanitization on top.
+ - This is not the default as such names break certain functionality as not all characters are valid Python identifiers
+ which group names end up being used as.
+ type: bool
+ default: False
+ version_added: '2.8'
+'''
+
+EXAMPLES = '''
+# Minimal example using environment vars or instance role credentials
+# Fetch all hosts in us-east-1, the hostname is the public DNS if it exists, otherwise the private IP address
+plugin: aws_ec2
+regions:
+ - us-east-1
+
+# Example using filters, ignoring permission errors, and specifying the hostname precedence
+plugin: aws_ec2
+boto_profile: aws_profile
+# Populate inventory with instances in these regions
+regions:
+ - us-east-1
+ - us-east-2
+filters:
+ # All instances with their `Environment` tag set to `dev`
+ tag:Environment: dev
+ # All dev and QA hosts
+ tag:Environment:
+ - dev
+ - qa
+ instance.group-id: sg-xxxxxxxx
+# Ignores 403 errors rather than failing
+strict_permissions: False
+# Note: I(hostnames) sets the inventory_hostname. To modify ansible_host without modifying
+# inventory_hostname use compose (see example below).
+hostnames:
+ - tag:Name=Tag1,Name=Tag2 # Return specific hosts only
+ - tag:CustomDNSName
+ - dns-name
+ - private-ip-address
+
+# Example using constructed features to create groups and set ansible_host
+plugin: aws_ec2
+regions:
+ - us-east-1
+ - us-west-1
+# keyed_groups may be used to create custom groups
+strict: False
+keyed_groups:
+ # Add e.g. x86_64 hosts to an arch_x86_64 group
+ - prefix: arch
+ key: 'architecture'
+ # Add hosts to tag_Name_Value groups for each Name/Value tag pair
+ - prefix: tag
+ key: tags
+ # Add hosts to e.g. instance_type_z3_tiny
+ - prefix: instance_type
+ key: instance_type
+ # Create security_groups_sg_abcd1234 group for each SG
+ - key: 'security_groups|json_query("[].group_id")'
+ prefix: 'security_groups'
+ # Create a group for each value of the Application tag
+ - key: tags.Application
+ separator: ''
+ # Create a group per region e.g. aws_region_us_east_2
+ - key: placement.region
+ prefix: aws_region
+ # Create a group (or groups) based on the value of a custom tag "Role" and add them to a metagroup called "project"
+ - key: tags['Role']
+ prefix: foo
+ parent_group: "project"
+# Set individual variables with compose
+compose:
+ # Use the private IP address to connect to the host
+ # (note: this does not modify inventory_hostname, which is set via I(hostnames))
+ ansible_host: private_ip_address
+'''
+
+import re
+
+from ansible.errors import AnsibleError
+from ansible.module_utils._text import to_native, to_text
+from ansible.module_utils.common.dict_transformations import camel_dict_to_snake_dict
+from ansible.plugins.inventory import BaseInventoryPlugin, Constructable, Cacheable
+from ansible.utils.display import Display
+from ansible.module_utils.six import string_types
+
+try:
+ import boto3
+ import botocore
+except ImportError:
+ raise AnsibleError('The ec2 dynamic inventory plugin requires boto3 and botocore.')
+
+display = Display()
+
+# The mappings give an array of keys to get from the filter name to the value
+# returned by boto3's EC2 describe_instances method.
+
+instance_meta_filter_to_boto_attr = {
+ 'group-id': ('Groups', 'GroupId'),
+ 'group-name': ('Groups', 'GroupName'),
+ 'network-interface.attachment.instance-owner-id': ('OwnerId',),
+ 'owner-id': ('OwnerId',),
+ 'requester-id': ('RequesterId',),
+ 'reservation-id': ('ReservationId',),
+}
+
+instance_data_filter_to_boto_attr = {
+ 'affinity': ('Placement', 'Affinity'),
+ 'architecture': ('Architecture',),
+ 'availability-zone': ('Placement', 'AvailabilityZone'),
+ 'block-device-mapping.attach-time': ('BlockDeviceMappings', 'Ebs', 'AttachTime'),
+ 'block-device-mapping.delete-on-termination': ('BlockDeviceMappings', 'Ebs', 'DeleteOnTermination'),
+ 'block-device-mapping.device-name': ('BlockDeviceMappings', 'DeviceName'),
+ 'block-device-mapping.status': ('BlockDeviceMappings', 'Ebs', 'Status'),
+ 'block-device-mapping.volume-id': ('BlockDeviceMappings', 'Ebs', 'VolumeId'),
+ 'client-token': ('ClientToken',),
+ 'dns-name': ('PublicDnsName',),
+ 'host-id': ('Placement', 'HostId'),
+ 'hypervisor': ('Hypervisor',),
+ 'iam-instance-profile.arn': ('IamInstanceProfile', 'Arn'),
+ 'image-id': ('ImageId',),
+ 'instance-id': ('InstanceId',),
+ 'instance-lifecycle': ('InstanceLifecycle',),
+ 'instance-state-code': ('State', 'Code'),
+ 'instance-state-name': ('State', 'Name'),
+ 'instance-type': ('InstanceType',),
+ 'instance.group-id': ('SecurityGroups', 'GroupId'),
+ 'instance.group-name': ('SecurityGroups', 'GroupName'),
+ 'ip-address': ('PublicIpAddress',),
+ 'kernel-id': ('KernelId',),
+ 'key-name': ('KeyName',),
+ 'launch-index': ('AmiLaunchIndex',),
+ 'launch-time': ('LaunchTime',),
+ 'monitoring-state': ('Monitoring', 'State'),
+ 'network-interface.addresses.private-ip-address': ('NetworkInterfaces', 'PrivateIpAddress'),
+ 'network-interface.addresses.primary': ('NetworkInterfaces', 'PrivateIpAddresses', 'Primary'),
+ 'network-interface.addresses.association.public-ip': ('NetworkInterfaces', 'PrivateIpAddresses', 'Association', 'PublicIp'),
+ 'network-interface.addresses.association.ip-owner-id': ('NetworkInterfaces', 'PrivateIpAddresses', 'Association', 'IpOwnerId'),
+ 'network-interface.association.public-ip': ('NetworkInterfaces', 'Association', 'PublicIp'),
+ 'network-interface.association.ip-owner-id': ('NetworkInterfaces', 'Association', 'IpOwnerId'),
+ 'network-interface.association.allocation-id': ('ElasticGpuAssociations', 'ElasticGpuId'),
+ 'network-interface.association.association-id': ('ElasticGpuAssociations', 'ElasticGpuAssociationId'),
+ 'network-interface.attachment.attachment-id': ('NetworkInterfaces', 'Attachment', 'AttachmentId'),
+ 'network-interface.attachment.instance-id': ('InstanceId',),
+ 'network-interface.attachment.device-index': ('NetworkInterfaces', 'Attachment', 'DeviceIndex'),
+ 'network-interface.attachment.status': ('NetworkInterfaces', 'Attachment', 'Status'),
+ 'network-interface.attachment.attach-time': ('NetworkInterfaces', 'Attachment', 'AttachTime'),
+ 'network-interface.attachment.delete-on-termination': ('NetworkInterfaces', 'Attachment', 'DeleteOnTermination'),
+ 'network-interface.availability-zone': ('Placement', 'AvailabilityZone'),
+ 'network-interface.description': ('NetworkInterfaces', 'Description'),
+ 'network-interface.group-id': ('NetworkInterfaces', 'Groups', 'GroupId'),
+ 'network-interface.group-name': ('NetworkInterfaces', 'Groups', 'GroupName'),
+ 'network-interface.ipv6-addresses.ipv6-address': ('NetworkInterfaces', 'Ipv6Addresses', 'Ipv6Address'),
+ 'network-interface.mac-address': ('NetworkInterfaces', 'MacAddress'),
+ 'network-interface.network-interface-id': ('NetworkInterfaces', 'NetworkInterfaceId'),
+ 'network-interface.owner-id': ('NetworkInterfaces', 'OwnerId'),
+ 'network-interface.private-dns-name': ('NetworkInterfaces', 'PrivateDnsName'),
+ # 'network-interface.requester-id': (),
+ 'network-interface.requester-managed': ('NetworkInterfaces', 'Association', 'IpOwnerId'),
+ 'network-interface.status': ('NetworkInterfaces', 'Status'),
+ 'network-interface.source-dest-check': ('NetworkInterfaces', 'SourceDestCheck'),
+ 'network-interface.subnet-id': ('NetworkInterfaces', 'SubnetId'),
+ 'network-interface.vpc-id': ('NetworkInterfaces', 'VpcId'),
+ 'placement-group-name': ('Placement', 'GroupName'),
+ 'platform': ('Platform',),
+ 'private-dns-name': ('PrivateDnsName',),
+ 'private-ip-address': ('PrivateIpAddress',),
+ 'product-code': ('ProductCodes', 'ProductCodeId'),
+ 'product-code.type': ('ProductCodes', 'ProductCodeType'),
+ 'ramdisk-id': ('RamdiskId',),
+ 'reason': ('StateTransitionReason',),
+ 'root-device-name': ('RootDeviceName',),
+ 'root-device-type': ('RootDeviceType',),
+ 'source-dest-check': ('SourceDestCheck',),
+ 'spot-instance-request-id': ('SpotInstanceRequestId',),
+ 'state-reason-code': ('StateReason', 'Code'),
+ 'state-reason-message': ('StateReason', 'Message'),
+ 'subnet-id': ('SubnetId',),
+ 'tag': ('Tags',),
+ 'tag-key': ('Tags',),
+ 'tag-value': ('Tags',),
+ 'tenancy': ('Placement', 'Tenancy'),
+ 'virtualization-type': ('VirtualizationType',),
+ 'vpc-id': ('VpcId',),
+}
+
+
+class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
+
+ NAME = 'aws_ec2'
+
+ def __init__(self):
+ super(InventoryModule, self).__init__()
+
+ self.group_prefix = 'aws_ec2_'
+
+ # credentials
+ self.boto_profile = None
+ self.aws_secret_access_key = None
+ self.aws_access_key_id = None
+ self.aws_security_token = None
+ self.iam_role_arn = None
+
+ def _compile_values(self, obj, attr):
+ '''
+ :param obj: A list or dict of instance attributes
+ :param attr: A key
+ :return: The value(s) found via the attr
+ '''
+ if obj is None:
+ return
+
+ temp_obj = []
+
+ if isinstance(obj, list) or isinstance(obj, tuple):
+ for each in obj:
+ value = self._compile_values(each, attr)
+ if value:
+ temp_obj.append(value)
+ else:
+ temp_obj = obj.get(attr)
+
+ has_indexes = any([isinstance(temp_obj, list), isinstance(temp_obj, tuple)])
+ if has_indexes and len(temp_obj) == 1:
+ return temp_obj[0]
+
+ return temp_obj
+
+ def _get_boto_attr_chain(self, filter_name, instance):
+ '''
+ :param filter_name: The filter
+ :param instance: instance dict returned by boto3 ec2 describe_instances()
+ '''
+ allowed_filters = sorted(list(instance_data_filter_to_boto_attr.keys()) + list(instance_meta_filter_to_boto_attr.keys()))
+ if filter_name not in allowed_filters:
+ raise AnsibleError("Invalid filter '%s' provided; filter must be one of %s." % (filter_name,
+ allowed_filters))
+ if filter_name in instance_data_filter_to_boto_attr:
+ boto_attr_list = instance_data_filter_to_boto_attr[filter_name]
+ else:
+ boto_attr_list = instance_meta_filter_to_boto_attr[filter_name]
+
+ instance_value = instance
+ for attribute in boto_attr_list:
+ instance_value = self._compile_values(instance_value, attribute)
+ return instance_value
+
+ def _get_credentials(self):
+ '''
+ :return: A dictionary of boto client credentials
+ '''
+ boto_params = {}
+ for credential in (('aws_access_key_id', self.aws_access_key_id),
+ ('aws_secret_access_key', self.aws_secret_access_key),
+ ('aws_session_token', self.aws_security_token)):
+ if credential[1]:
+ boto_params[credential[0]] = credential[1]
+
+ return boto_params
+
+ def _get_connection(self, credentials, region='us-east-1'):
+ try:
+ connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region, **credentials)
+ except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e:
+ if self.boto_profile:
+ try:
+ connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region)
+ except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e:
+ raise AnsibleError("Insufficient credentials found: %s" % to_native(e))
+ else:
+ raise AnsibleError("Insufficient credentials found: %s" % to_native(e))
+ return connection
+
+ def _boto3_assume_role(self, credentials, region):
+ """
+ Assume an IAM role passed by iam_role_arn parameter
+
+ :return: a dict containing the credentials of the assumed role
+ """
+
+ iam_role_arn = self.iam_role_arn
+
+ try:
+ sts_connection = boto3.session.Session(profile_name=self.boto_profile).client('sts', region, **credentials)
+ sts_session = sts_connection.assume_role(RoleArn=iam_role_arn, RoleSessionName='ansible_aws_ec2_dynamic_inventory')
+ return dict(
+ aws_access_key_id=sts_session['Credentials']['AccessKeyId'],
+ aws_secret_access_key=sts_session['Credentials']['SecretAccessKey'],
+ aws_session_token=sts_session['Credentials']['SessionToken']
+ )
+ except botocore.exceptions.ClientError as e:
+ raise AnsibleError("Unable to assume IAM role: %s" % to_native(e))
+
+ def _boto3_conn(self, regions):
+ '''
+ :param regions: A list of regions to create a boto3 client
+
+ Generator that yields a boto3 client and the region
+ '''
+
+ credentials = self._get_credentials()
+ iam_role_arn = self.iam_role_arn
+
+ if not regions:
+ try:
+ # as per https://boto3.amazonaws.com/v1/documentation/api/latest/guide/ec2-example-regions-avail-zones.html
+ client = self._get_connection(credentials)
+ resp = client.describe_regions()
+ regions = [x['RegionName'] for x in resp.get('Regions', [])]
+ except botocore.exceptions.NoRegionError:
+ # the above seems to fail depending on the boto3 version; ignore and try something else
+ pass
+
+ # fallback to local list hardcoded in boto3 if still no regions
+ if not regions:
+ session = boto3.Session()
+ regions = session.get_available_regions('ec2')
+
+ # I give up, now you MUST give me regions
+ if not regions:
+ raise AnsibleError('Unable to get regions list from available methods, you must specify the "regions" option to continue.')
+
+ for region in regions:
+ connection = self._get_connection(credentials, region)
+ try:
+ if iam_role_arn is not None:
+ assumed_credentials = self._boto3_assume_role(credentials, region)
+ else:
+ assumed_credentials = credentials
+ connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region, **assumed_credentials)
+ except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e:
+ if self.boto_profile:
+ try:
+ connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region)
+ except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e:
+ raise AnsibleError("Insufficient credentials found: %s" % to_native(e))
+ else:
+ raise AnsibleError("Insufficient credentials found: %s" % to_native(e))
+ yield connection, region
+
+ def _get_instances_by_region(self, regions, filters, strict_permissions):
+ '''
+ :param regions: a list of regions in which to describe instances
+ :param filters: a list of boto3 filter dictionaries
+ :param strict_permissions: a boolean determining whether to fail or ignore 403 error codes
+ :return: A list of instance dictionaries
+ '''
+ all_instances = []
+
+ for connection, region in self._boto3_conn(regions):
+ try:
+ # By default find non-terminated/terminating instances
+ if not any([f['Name'] == 'instance-state-name' for f in filters]):
+ filters.append({'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopping', 'stopped']})
+ paginator = connection.get_paginator('describe_instances')
+ reservations = paginator.paginate(Filters=filters).build_full_result().get('Reservations')
+ instances = []
+ for r in reservations:
+ new_instances = r['Instances']
+ for instance in new_instances:
+ instance.update(self._get_reservation_details(r))
+ if self.get_option('include_extra_api_calls'):
+ instance.update(self._get_event_set_and_persistence(connection, instance['InstanceId'], instance.get('SpotInstanceRequestId')))
+ instances.extend(new_instances)
+ except botocore.exceptions.ClientError as e:
+ if e.response['ResponseMetadata']['HTTPStatusCode'] == 403 and not strict_permissions:
+ instances = []
+ else:
+ raise AnsibleError("Failed to describe instances: %s" % to_native(e))
+ except botocore.exceptions.BotoCoreError as e:
+ raise AnsibleError("Failed to describe instances: %s" % to_native(e))
+
+ all_instances.extend(instances)
+
+ return sorted(all_instances, key=lambda x: x['InstanceId'])
+
+ def _get_reservation_details(self, reservation):
+ return {
+ 'OwnerId': reservation['OwnerId'],
+ 'RequesterId': reservation.get('RequesterId', ''),
+ 'ReservationId': reservation['ReservationId']
+ }
+
+ def _get_event_set_and_persistence(self, connection, instance_id, spot_instance):
+ host_vars = {'Events': '', 'Persistent': False}
+ try:
+ kwargs = {'InstanceIds': [instance_id]}
+ host_vars['Events'] = connection.describe_instance_status(**kwargs)['InstanceStatuses'][0].get('Events', '')
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ if not self.get_option('strict_permissions'):
+ pass
+ else:
+ raise AnsibleError("Failed to describe instance status: %s" % to_native(e))
+ if spot_instance:
+ try:
+ kwargs = {'SpotInstanceRequestIds': [spot_instance]}
+ host_vars['Persistent'] = bool(
+ connection.describe_spot_instance_requests(**kwargs)['SpotInstanceRequests'][0].get('Type') == 'persistent'
+ )
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ if not self.get_option('strict_permissions'):
+ pass
+ else:
+ raise AnsibleError("Failed to describe spot instance requests: %s" % to_native(e))
+ return host_vars
+
+ def _get_tag_hostname(self, preference, instance):
+ tag_hostnames = preference.split('tag:', 1)[1]
+ if ',' in tag_hostnames:
+ tag_hostnames = tag_hostnames.split(',')
+ else:
+ tag_hostnames = [tag_hostnames]
+ tags = boto3_tag_list_to_ansible_dict(instance.get('Tags', []))
+ for v in tag_hostnames:
+ if '=' in v:
+ tag_name, tag_value = v.split('=')
+ if tags.get(tag_name) == tag_value:
+ return to_text(tag_name) + "_" + to_text(tag_value)
+ else:
+ tag_value = tags.get(v)
+ if tag_value:
+ return to_text(tag_value)
+ return None
+
+ def _get_hostname(self, instance, hostnames):
+ '''
+ :param instance: an instance dict returned by boto3 ec2 describe_instances()
+ :param hostnames: a list of hostname destination variables in order of preference
+ :return: the preferred identifier for the host
+ '''
+ if not hostnames:
+ hostnames = ['dns-name', 'private-dns-name']
+
+ hostname = None
+ for preference in hostnames:
+ if 'tag' in preference:
+ if not preference.startswith('tag:'):
+ raise AnsibleError("To name a host by tags name_value, use 'tag:name=value'.")
+ hostname = self._get_tag_hostname(preference, instance)
+ else:
+ hostname = self._get_boto_attr_chain(preference, instance)
+ if hostname:
+ break
+ if hostname:
+ if ':' in to_text(hostname):
+ return self._sanitize_group_name((to_text(hostname)))
+ else:
+ return to_text(hostname)
+
+ def _query(self, regions, filters, strict_permissions):
+ '''
+ :param regions: a list of regions to query
+ :param filters: a list of boto3 filter dictionaries
+ :param strict_permissions: a boolean determining whether to fail or ignore 403 error codes
+ '''
+ return {'aws_ec2': self._get_instances_by_region(regions, filters, strict_permissions)}
+
+ def _populate(self, groups, hostnames):
+ for group in groups:
+ group = self.inventory.add_group(group)
+ self._add_hosts(hosts=groups[group], group=group, hostnames=hostnames)
+ self.inventory.add_child('all', group)
+
+ def _add_hosts(self, hosts, group, hostnames):
+ '''
+ :param hosts: a list of hosts to be added to a group
+ :param group: the name of the group to which the hosts belong
+ :param hostnames: a list of hostname destination variables in order of preference
+ '''
+ for host in hosts:
+ hostname = self._get_hostname(host, hostnames)
+
+ host = camel_dict_to_snake_dict(host, ignore_list=['Tags'])
+ host['tags'] = boto3_tag_list_to_ansible_dict(host.get('tags', []))
+
+ # Allow easier grouping by region
+ host['placement']['region'] = host['placement']['availability_zone'][:-1]
+
+ if not hostname:
+ continue
+ self.inventory.add_host(hostname, group=group)
+ for hostvar, hostval in host.items():
+ self.inventory.set_variable(hostname, hostvar, hostval)
+
+ # Use constructed if applicable
+
+ strict = self.get_option('strict')
+
+ # Composed variables
+ self._set_composite_vars(self.get_option('compose'), host, hostname, strict=strict)
+
+ # Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group
+ self._add_host_to_composed_groups(self.get_option('groups'), host, hostname, strict=strict)
+
+ # Create groups based on variable values and add the corresponding hosts to it
+ self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host, hostname, strict=strict)
+
+ def _set_credentials(self):
+ '''
+ Read credentials from the plugin options, falling back to the botocore session.
+ '''
+
+ self.boto_profile = self.get_option('aws_profile')
+ self.aws_access_key_id = self.get_option('aws_access_key')
+ self.aws_secret_access_key = self.get_option('aws_secret_key')
+ self.aws_security_token = self.get_option('aws_security_token')
+ self.iam_role_arn = self.get_option('iam_role_arn')
+
+ if not self.boto_profile and not (self.aws_access_key_id and self.aws_secret_access_key):
+ session = botocore.session.get_session()
+ try:
+ credentials = session.get_credentials().get_frozen_credentials()
+ except AttributeError:
+ pass
+ else:
+ self.aws_access_key_id = credentials.access_key
+ self.aws_secret_access_key = credentials.secret_key
+ self.aws_security_token = credentials.token
+
+ if not self.boto_profile and not (self.aws_access_key_id and self.aws_secret_access_key):
+ raise AnsibleError("Insufficient boto credentials found. Please provide them in your "
+ "inventory configuration file or set them as environment variables.")
+
+ def verify_file(self, path):
+ '''
+ :param path: the path to the inventory config file
+ :return: True if the file name matches aws_ec2.(yml|yaml), otherwise False
+ '''
+ if super(InventoryModule, self).verify_file(path):
+ if path.endswith(('aws_ec2.yml', 'aws_ec2.yaml')):
+ return True
+ display.debug("aws_ec2 inventory filename must end with 'aws_ec2.yml' or 'aws_ec2.yaml'")
+ return False
+
+ def parse(self, inventory, loader, path, cache=True):
+
+ super(InventoryModule, self).parse(inventory, loader, path)
+
+ self._read_config_data(path)
+
+ if self.get_option('use_contrib_script_compatible_sanitization'):
+ self._sanitize_group_name = self._legacy_script_compatible_group_sanitization
+
+ self._set_credentials()
+
+ # get user specifications
+ regions = self.get_option('regions')
+ filters = ansible_dict_to_boto3_filter_list(self.get_option('filters'))
+ hostnames = self.get_option('hostnames')
+ strict_permissions = self.get_option('strict_permissions')
+
+ cache_key = self.get_cache_key(path)
+ # false when refresh_cache or --flush-cache is used
+ if cache:
+ # get the user-specified directive
+ cache = self.get_option('cache')
+
+ # Generate inventory
+ cache_needs_update = False
+ if cache:
+ try:
+ results = self._cache[cache_key]
+ except KeyError:
+ # if cache expires or cache file doesn't exist
+ cache_needs_update = True
+
+ if not cache or cache_needs_update:
+ results = self._query(regions, filters, strict_permissions)
+
+ self._populate(results, hostnames)
+
+ # If the cache has expired/doesn't exist or if refresh_inventory/flush cache is used
+ # when the user is using caching, update the cached inventory
+ if cache_needs_update or (not cache and self.get_option('cache')):
+ self._cache[cache_key] = results
+
+ @staticmethod
+ def _legacy_script_compatible_group_sanitization(name):
+
+ # Note: while this mirrors what the script used to do, it has many issues with unicode and usability in Python.
+ regex = re.compile(r"[^A-Za-z0-9\_\-]")
+
+ return regex.sub('_', name)
+
+
+def ansible_dict_to_boto3_filter_list(filters_dict):
+
+ """ Convert an Ansible dict of filters to list of dicts that boto3 can use
+ Args:
+ filters_dict (dict): Dict of AWS filters.
+ Basic Usage:
+ >>> filters = {'some-aws-id': 'i-01234567'}
+ >>> ansible_dict_to_boto3_filter_list(filters)
+ [{'Name': 'some-aws-id', 'Values': ['i-01234567']}]
+ Returns:
+ List: List of AWS filters and their values
+ [
+ {
+ 'Name': 'some-aws-id',
+ 'Values': [
+ 'i-01234567',
+ ]
+ }
+ ]
+ """
+
+ filters_list = []
+ for k, v in filters_dict.items():
+ filter_dict = {'Name': k}
+ if isinstance(v, string_types):
+ filter_dict['Values'] = [v]
+ else:
+ filter_dict['Values'] = v
+
+ filters_list.append(filter_dict)
+
+ return filters_list
+
+
+def boto3_tag_list_to_ansible_dict(tags_list, tag_name_key_name=None, tag_value_key_name=None):
+
+ """ Convert a boto3 list of resource tags to a flat dict of key:value pairs
+ Args:
+ tags_list (list): List of dicts representing AWS tags.
+ tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key")
+ tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value")
+ Basic Usage:
+ >>> tags_list = [{'Key': 'MyTagKey', 'Value': 'MyTagValue'}]
+ >>> boto3_tag_list_to_ansible_dict(tags_list)
+ {'MyTagKey': 'MyTagValue'}
+ Returns:
+ Dict: Dict of key:value pairs representing AWS tags
+ {
+ 'MyTagKey': 'MyTagValue',
+ }
+ """
+
+ if tag_name_key_name and tag_value_key_name:
+ tag_candidates = {tag_name_key_name: tag_value_key_name}
+ else:
+ tag_candidates = {'key': 'value', 'Key': 'Value'}
+
+ if not tags_list:
+ return {}
+ for k, v in tag_candidates.items():
+ if k in tags_list[0] and v in tags_list[0]:
+ return dict((tag[k], tag[v]) for tag in tags_list)
+ raise ValueError("Couldn't find tag key (candidates %s) in tag list %s" % (str(tag_candidates), str(tags_list)))
diff --git a/test/support/integration/plugins/inventory/docker_swarm.py b/test/support/integration/plugins/inventory/docker_swarm.py
new file mode 100644
index 00000000..d0a95ca0
--- /dev/null
+++ b/test/support/integration/plugins/inventory/docker_swarm.py
@@ -0,0 +1,351 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2018, Stefan Heitmueller <stefan.heitmueller@gmx.com>
+# Copyright (c) 2018 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+
+__metaclass__ = type
+
+DOCUMENTATION = '''
+ name: docker_swarm
+ plugin_type: inventory
+ version_added: '2.8'
+ author:
+ - Stefan Heitmüller (@morph027) <stefan.heitmueller@gmx.com>
+ short_description: Ansible dynamic inventory plugin for Docker swarm nodes.
+ requirements:
+ - python >= 2.7
+ - L(Docker SDK for Python,https://docker-py.readthedocs.io/en/stable/) >= 1.10.0
+ extends_documentation_fragment:
+ - constructed
+ description:
+ - Reads inventories from the Docker swarm API.
+ - Uses a YAML configuration file docker_swarm.[yml|yaml].
+ - "The plugin returns following groups of swarm nodes: I(all) - all hosts; I(workers) - all worker nodes;
+ I(managers) - all manager nodes; I(leader) - the swarm leader node;
+ I(nonleaders) - all nodes except the swarm leader."
+ options:
+ plugin:
+ description: The name of this plugin; it should always be set to C(docker_swarm) for this plugin to
+ recognize it as its own.
+ type: str
+ required: true
+ choices: ['docker_swarm']
+ docker_host:
+ description:
+ - Socket of a Docker swarm manager node (C(tcp), C(unix)).
+ - "Use C(unix://var/run/docker.sock) to connect via local socket."
+ type: str
+ required: true
+ aliases: [ docker_url ]
+ verbose_output:
+ description: Toggle whether to include all available node metadata (e.g. C(Platform), C(Architecture), C(OS),
+ C(EngineVersion)).
+ type: bool
+ default: yes
+ tls:
+ description: Connect using TLS without verifying the authenticity of the Docker host server.
+ type: bool
+ default: no
+ validate_certs:
+ description: Toggle whether to verify the authenticity of the Docker host server when connecting
+ using TLS.
+ type: bool
+ default: no
+ aliases: [ tls_verify ]
+ client_key:
+ description: Path to the client's TLS key file.
+ type: path
+ aliases: [ tls_client_key, key_path ]
+ ca_cert:
+ description: Use a CA certificate when performing server verification by providing the path to a CA
+ certificate file.
+ type: path
+ aliases: [ tls_ca_cert, cacert_path ]
+ client_cert:
+ description: Path to the client's TLS certificate file.
+ type: path
+ aliases: [ tls_client_cert, cert_path ]
+ tls_hostname:
+ description: When verifying the authenticity of the Docker host server, provide the expected name of
+ the server.
+ type: str
+ ssl_version:
+ description: Provide a valid SSL version number. Default value determined by ssl.py module.
+ type: str
+ api_version:
+ description:
+ - The version of the Docker API running on the Docker Host.
+ - Defaults to the latest version of the API supported by docker-py.
+ type: str
+ aliases: [ docker_api_version ]
+ timeout:
+ description:
+ - The maximum amount of time in seconds to wait on a response from the API.
+ - If the value is not specified in the task, the value of environment variable C(DOCKER_TIMEOUT)
+ will be used instead. If the environment variable is not set, the default value will be used.
+ type: int
+ default: 60
+ aliases: [ time_out ]
+ include_host_uri:
+ description: Toggle to return the additional attribute C(ansible_host_uri), which contains the URI of the
+ swarm leader in the format C(tcp://172.16.0.1:2376). This value may be used as-is as the value of the
+ I(docker_host) option in Docker Swarm modules when connecting via the API.
+ The port always defaults to C(2376).
+ type: bool
+ default: no
+ include_host_uri_port:
+ description: Override the detected port number included in I(ansible_host_uri)
+ type: int
+'''
+
+EXAMPLES = '''
+# Minimal example using local docker
+plugin: docker_swarm
+docker_host: unix://var/run/docker.sock
+
+# Minimal example using remote docker
+plugin: docker_swarm
+docker_host: tcp://my-docker-host:2375
+
+# Example using remote docker with unverified TLS
+plugin: docker_swarm
+docker_host: tcp://my-docker-host:2376
+tls: yes
+
+# Example using remote docker with verified TLS and client certificate verification
+plugin: docker_swarm
+docker_host: tcp://my-docker-host:2376
+validate_certs: yes
+ca_cert: /somewhere/ca.pem
+client_key: /somewhere/key.pem
+client_cert: /somewhere/cert.pem
+
+# Example using constructed features to create groups and set ansible_host
+plugin: docker_swarm
+docker_host: tcp://my-docker-host:2375
+strict: False
+keyed_groups:
+ # add e.g. x86_64 hosts to an arch_x86_64 group
+ - prefix: arch
+ key: 'Description.Platform.Architecture'
+ # add e.g. linux hosts to an os_linux group
+ - prefix: os
+ key: 'Description.Platform.OS'
+ # create a group per node label
+ # e.g. a node labeled w/ "production" ends up in group "label_production"
+ # hint: labels containing special characters will be converted to safe names
+ - key: 'Spec.Labels'
+ prefix: label
+'''
+
+from ansible.errors import AnsibleError
+from ansible.module_utils._text import to_native
+from ansible.module_utils.six.moves.urllib.parse import urlparse
+from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
+from ansible.parsing.utils.addresses import parse_address
+
+try:
+ import docker
+ from docker.errors import TLSParameterError
+ from docker.tls import TLSConfig
+ HAS_DOCKER = True
+except ImportError:
+ HAS_DOCKER = False
+
+
+def update_tls_hostname(result):
+ if result['tls_hostname'] is None:
+ # get default machine name from the url
+ parsed_url = urlparse(result['docker_host'])
+ if ':' in parsed_url.netloc:
+ result['tls_hostname'] = parsed_url.netloc[:parsed_url.netloc.rindex(':')]
+ else:
+ result['tls_hostname'] = parsed_url.netloc
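+
+# Illustrative sketch (hypothetical values): given
+# {'tls_hostname': None, 'docker_host': 'tcp://my-docker-host:2376'},
+# update_tls_hostname() strips the port and sets tls_hostname to 'my-docker-host'.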
+
+
+def _get_tls_config(fail_function, **kwargs):
+ try:
+ tls_config = TLSConfig(**kwargs)
+ return tls_config
+ except TLSParameterError as exc:
+ fail_function("TLS config error: %s" % exc)
+
+
+def get_connect_params(auth, fail_function):
+ if auth['tls'] or auth['tls_verify']:
+ auth['docker_host'] = auth['docker_host'].replace('tcp://', 'https://')
+
+ if auth['tls_verify'] and auth['cert_path'] and auth['key_path']:
+ # TLS with certs and host verification
+ if auth['cacert_path']:
+ tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']),
+ ca_cert=auth['cacert_path'],
+ verify=True,
+ assert_hostname=auth['tls_hostname'],
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ else:
+ tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']),
+ verify=True,
+ assert_hostname=auth['tls_hostname'],
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls_verify'] and auth['cacert_path']:
+ # TLS with cacert only
+ tls_config = _get_tls_config(ca_cert=auth['cacert_path'],
+ assert_hostname=auth['tls_hostname'],
+ verify=True,
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls_verify']:
+ # TLS with verify and no certs
+ tls_config = _get_tls_config(verify=True,
+ assert_hostname=auth['tls_hostname'],
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls'] and auth['cert_path'] and auth['key_path']:
+ # TLS with certs and no host verification
+ tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']),
+ verify=False,
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls']:
+ # TLS with no certs and no host verification
+ tls_config = _get_tls_config(verify=False,
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ # No TLS
+ return dict(base_url=auth['docker_host'],
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+
+class InventoryModule(BaseInventoryPlugin, Constructable):
+ ''' Host inventory parser for ansible using Docker swarm as source. '''
+
+ NAME = 'docker_swarm'
+
+ def _fail(self, msg):
+ raise AnsibleError(msg)
+
+ def _populate(self):
+ raw_params = dict(
+ docker_host=self.get_option('docker_host'),
+ tls=self.get_option('tls'),
+ tls_verify=self.get_option('validate_certs'),
+ key_path=self.get_option('client_key'),
+ cacert_path=self.get_option('ca_cert'),
+ cert_path=self.get_option('client_cert'),
+ tls_hostname=self.get_option('tls_hostname'),
+ api_version=self.get_option('api_version'),
+ timeout=self.get_option('timeout'),
+ ssl_version=self.get_option('ssl_version'),
+ debug=None,
+ )
+ update_tls_hostname(raw_params)
+ connect_params = get_connect_params(raw_params, fail_function=self._fail)
+ self.client = docker.DockerClient(**connect_params)
+ self.inventory.add_group('all')
+ self.inventory.add_group('manager')
+ self.inventory.add_group('worker')
+ self.inventory.add_group('leader')
+ self.inventory.add_group('nonleaders')
+
+ if self.get_option('include_host_uri'):
+ if self.get_option('include_host_uri_port'):
+ host_uri_port = str(self.get_option('include_host_uri_port'))
+ elif self.get_option('tls') or self.get_option('validate_certs'):
+ host_uri_port = '2376'
+ else:
+ host_uri_port = '2375'
+
+ try:
+ self.nodes = self.client.nodes.list()
+ for self.node in self.nodes:
+ self.node_attrs = self.client.nodes.get(self.node.id).attrs
+ self.inventory.add_host(self.node_attrs['ID'])
+ self.inventory.add_host(self.node_attrs['ID'], group=self.node_attrs['Spec']['Role'])
+ self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host',
+ self.node_attrs['Status']['Addr'])
+ if self.get_option('include_host_uri'):
+ self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host_uri',
+ 'tcp://' + self.node_attrs['Status']['Addr'] + ':' + host_uri_port)
+ if self.get_option('verbose_output'):
+ self.inventory.set_variable(self.node_attrs['ID'], 'docker_swarm_node_attributes', self.node_attrs)
+ if 'ManagerStatus' in self.node_attrs:
+ if self.node_attrs['ManagerStatus'].get('Leader'):
+ # This is a workaround for a Docker bug where in some cases the leader IP is 0.0.0.0
+ # Check moby/moby#35437 for details
+ swarm_leader_ip = parse_address(self.node_attrs['ManagerStatus']['Addr'])[0] or \
+ self.node_attrs['Status']['Addr']
+ if self.get_option('include_host_uri'):
+ self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host_uri',
+ 'tcp://' + swarm_leader_ip + ':' + host_uri_port)
+ self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host', swarm_leader_ip)
+ self.inventory.add_host(self.node_attrs['ID'], group='leader')
+ else:
+ self.inventory.add_host(self.node_attrs['ID'], group='nonleaders')
+ else:
+ self.inventory.add_host(self.node_attrs['ID'], group='nonleaders')
+ # Use constructed if applicable
+ strict = self.get_option('strict')
+ # Composed variables
+ self._set_composite_vars(self.get_option('compose'),
+ self.node_attrs,
+ self.node_attrs['ID'],
+ strict=strict)
+ # Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group
+ self._add_host_to_composed_groups(self.get_option('groups'),
+ self.node_attrs,
+ self.node_attrs['ID'],
+ strict=strict)
+ # Create groups based on variable values and add the corresponding hosts to it
+ self._add_host_to_keyed_groups(self.get_option('keyed_groups'),
+ self.node_attrs,
+ self.node_attrs['ID'],
+ strict=strict)
+ except Exception as e:
+ raise AnsibleError('Unable to fetch hosts from Docker swarm API, this was the original exception: %s' %
+ to_native(e))
+
+ def verify_file(self, path):
+ """Return the possibly of a file being consumable by this plugin."""
+ return (
+ super(InventoryModule, self).verify_file(path) and
+ path.endswith((self.NAME + '.yaml', self.NAME + '.yml')))
+
+ def parse(self, inventory, loader, path, cache=True):
+ if not HAS_DOCKER:
+ raise AnsibleError('The Docker swarm dynamic inventory plugin requires the Docker SDK for Python: '
+ 'https://github.com/docker/docker-py.')
+ super(InventoryModule, self).parse(inventory, loader, path, cache)
+ self._read_config_data(path)
+ self._populate()
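+
+# Illustrative usage sketch (assuming the minimal example from EXAMPLES above
+# is saved as docker_swarm.yml):
+#
+#     ansible-inventory -i docker_swarm.yml --list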
diff --git a/test/support/integration/plugins/inventory/foreman.py b/test/support/integration/plugins/inventory/foreman.py
new file mode 100644
index 00000000..43073f81
--- /dev/null
+++ b/test/support/integration/plugins/inventory/foreman.py
@@ -0,0 +1,295 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2016 Guido Günther <agx@sigxcpu.org>, Daniel Lobato Garcia <dlobatog@redhat.com>
+# Copyright (c) 2018 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+DOCUMENTATION = '''
+ name: foreman
+ plugin_type: inventory
+ short_description: foreman inventory source
+ version_added: "2.6"
+ requirements:
+ - requests >= 1.1
+ description:
+ - Get inventory hosts from the foreman service.
+ - "Uses a configuration file as an inventory source, it must end in ``.foreman.yml`` or ``.foreman.yaml`` and has a ``plugin: foreman`` entry."
+ extends_documentation_fragment:
+ - inventory_cache
+ - constructed
+ options:
+ plugin:
+ description: The name of this plugin. It should always be set to 'foreman' for this plugin to recognize it as its own.
+ required: True
+ choices: ['foreman']
+ url:
+ description: URL of the foreman service
+ default: 'http://localhost:3000'
+ env:
+ - name: FOREMAN_SERVER
+ version_added: "2.8"
+ user:
+ description: foreman authentication user
+ required: True
+ env:
+ - name: FOREMAN_USER
+ version_added: "2.8"
+ password:
+ description: foreman authentication password
+ required: True
+ env:
+ - name: FOREMAN_PASSWORD
+ version_added: "2.8"
+ validate_certs:
+ description: verify SSL certificate if using https
+ type: boolean
+ default: False
+ group_prefix:
+ description: prefix to apply to foreman groups
+ default: foreman_
+ vars_prefix:
+ description: prefix to apply to host variables; does not apply to facts or params
+ default: foreman_
+ want_facts:
+ description: Toggle, if True the plugin will retrieve host facts from the server
+ type: boolean
+ default: False
+ want_params:
+ description: Toggle, if true the inventory will retrieve 'all_parameters' information as host vars
+ type: boolean
+ default: False
+ want_hostcollections:
+ description: Toggle, if true the plugin will create Ansible groups for host collections
+ type: boolean
+ default: False
+ version_added: '2.10'
+ want_ansible_ssh_host:
+ description: Toggle, if true the plugin will populate the ansible_ssh_host variable to explicitly specify the connection target
+ type: boolean
+ default: False
+ version_added: '2.10'
+
+'''
+
+EXAMPLES = '''
+# my.foreman.yml
+plugin: foreman
+url: http://localhost:2222
+user: ansible-tester
+password: secure
+validate_certs: False
+'''
+
+from distutils.version import LooseVersion
+
+from ansible.errors import AnsibleError
+from ansible.module_utils._text import to_bytes, to_native, to_text
+from ansible.module_utils.common._collections_compat import MutableMapping
+from ansible.plugins.inventory import BaseInventoryPlugin, Cacheable, to_safe_group_name, Constructable
+
+# 3rd party imports
+try:
+ import requests
+ if LooseVersion(requests.__version__) < LooseVersion('1.1.0'):
+ raise ImportError
+except ImportError:
+ raise AnsibleError('This script requires python-requests 1.1 as a minimum version')
+
+from requests.auth import HTTPBasicAuth
+
+
+class InventoryModule(BaseInventoryPlugin, Cacheable, Constructable):
+ ''' Host inventory parser for ansible using foreman as source. '''
+
+ NAME = 'foreman'
+
+ def __init__(self):
+
+ super(InventoryModule, self).__init__()
+
+ # from config
+ self.foreman_url = None
+
+ self.session = None
+ self.cache_key = None
+ self.use_cache = None
+
+ def verify_file(self, path):
+
+ valid = False
+ if super(InventoryModule, self).verify_file(path):
+ if path.endswith(('foreman.yaml', 'foreman.yml')):
+ valid = True
+ else:
+ self.display.vvv('Skipping due to inventory source not ending in "foreman.yaml" or "foreman.yml"')
+ return valid
+
+ def _get_session(self):
+ if not self.session:
+ self.session = requests.session()
+ self.session.auth = HTTPBasicAuth(self.get_option('user'), to_bytes(self.get_option('password')))
+ self.session.verify = self.get_option('validate_certs')
+ return self.session
+
+ def _get_json(self, url, ignore_errors=None):
+
+ if not self.use_cache or url not in self._cache.get(self.cache_key, {}):
+
+ if self.cache_key not in self._cache:
+ self._cache[self.cache_key] = {url: ''}
+
+ results = []
+ s = self._get_session()
+ params = {'page': 1, 'per_page': 250}
+ while True:
+ ret = s.get(url, params=params)
+ if ignore_errors and ret.status_code in ignore_errors:
+ break
+ ret.raise_for_status()
+ json = ret.json()
+
+ # process results
+ # FIXME: This assumes the 'return type' matches a specific query;
+ # it will break if we expand the queries and they don't have different types
+ if 'results' not in json:
+ # /hosts/:id does not have a 'results' key
+ results = json
+ break
+ elif isinstance(json['results'], MutableMapping):
+ # /facts are returned as dict in 'results'
+ results = json['results']
+ break
+ else:
+ # for /hosts, 'results' is a paginated list of all hosts
+ results = results + json['results']
+
+ # check for end of paging
+ if len(results) >= json['subtotal']:
+ break
+ if len(json['results']) == 0:
+ self.display.warning("Did not make any progress during loop. expected %d got %d" % (json['subtotal'], len(results)))
+ break
+
+ # get next page
+ params['page'] += 1
+
+ self._cache[self.cache_key][url] = results
+
+ return self._cache[self.cache_key][url]
+
+ def _get_hosts(self):
+ return self._get_json("%s/api/v2/hosts" % self.foreman_url)
+
+ def _get_all_params_by_id(self, hid):
+ url = "%s/api/v2/hosts/%s" % (self.foreman_url, hid)
+ ret = self._get_json(url, [404])
+ if not ret or not isinstance(ret, MutableMapping) or not ret.get('all_parameters', False):
+ return {}
+ return ret.get('all_parameters')
+
+ def _get_facts_by_id(self, hid):
+ url = "%s/api/v2/hosts/%s/facts" % (self.foreman_url, hid)
+ return self._get_json(url)
+
+ def _get_host_data_by_id(self, hid):
+ url = "%s/api/v2/hosts/%s" % (self.foreman_url, hid)
+ return self._get_json(url)
+
+ def _get_facts(self, host):
+ """Fetch all host facts of the host"""
+
+ ret = self._get_facts_by_id(host['id'])
+ if len(ret.values()) == 0:
+ facts = {}
+ elif len(ret.values()) == 1:
+ facts = list(ret.values())[0]
+ else:
+ raise ValueError("More than one set of facts returned for '%s'" % host)
+ return facts
+
+ def _populate(self):
+
+ for host in self._get_hosts():
+
+ if host.get('name'):
+ host_name = self.inventory.add_host(host['name'])
+
+ # create directly mapped groups
+ group_name = host.get('hostgroup_title', host.get('hostgroup_name'))
+ if group_name:
+ group_name = to_safe_group_name('%s%s' % (self.get_option('group_prefix'), group_name.lower().replace(" ", "")))
+ group_name = self.inventory.add_group(group_name)
+ self.inventory.add_child(group_name, host_name)
+
+ # set host vars from host info
+ try:
+ for k, v in host.items():
+ if k not in ('name', 'hostgroup_title', 'hostgroup_name'):
+ try:
+ self.inventory.set_variable(host_name, self.get_option('vars_prefix') + k, v)
+ except ValueError as e:
+ self.display.warning("Could not set host info hostvar for %s, skipping %s: %s" % (host, k, to_text(e)))
+ except ValueError as e:
+ self.display.warning("Could not get host info for %s, skipping: %s" % (host_name, to_text(e)))
+
+ # set host vars from params
+ if self.get_option('want_params'):
+ for p in self._get_all_params_by_id(host['id']):
+ try:
+ self.inventory.set_variable(host_name, p['name'], p['value'])
+ except ValueError as e:
+ self.display.warning("Could not set hostvar %s to '%s' for the '%s' host, skipping: %s" %
+ (p['name'], to_native(p['value']), host, to_native(e)))
+
+ # set host vars from facts
+ if self.get_option('want_facts'):
+ self.inventory.set_variable(host_name, 'foreman_facts', self._get_facts(host))
+
+ # create group for host collections
+ if self.get_option('want_hostcollections'):
+ host_data = self._get_host_data_by_id(host['id'])
+ hostcollections = host_data.get('host_collections')
+ if hostcollections:
+ # Create Ansible groups for host collections
+ for hostcollection in hostcollections:
+ try:
+ hostcollection_group = to_safe_group_name('%shostcollection_%s' % (self.get_option('group_prefix'),
+ hostcollection['name'].lower().replace(" ", "")))
+ hostcollection_group = self.inventory.add_group(hostcollection_group)
+ self.inventory.add_child(hostcollection_group, host_name)
+ except ValueError as e:
+ self.display.warning("Could not create groups for host collections for %s, skipping: %s" % (host_name, to_text(e)))
+
+ # put ansible_ssh_host as hostvar
+ if self.get_option('want_ansible_ssh_host'):
+ for key in ('ip', 'ipv4', 'ipv6'):
+ if host.get(key):
+ try:
+ self.inventory.set_variable(host_name, 'ansible_ssh_host', host[key])
+ break
+ except ValueError as e:
+ self.display.warning("Could not set hostvar ansible_ssh_host to '%s' for the '%s' host, skipping: %s" %
+ (host[key], host_name, to_text(e)))
+
+ strict = self.get_option('strict')
+
+ hostvars = self.inventory.get_host(host_name).get_vars()
+ self._set_composite_vars(self.get_option('compose'), hostvars, host_name, strict)
+ self._add_host_to_composed_groups(self.get_option('groups'), hostvars, host_name, strict)
+ self._add_host_to_keyed_groups(self.get_option('keyed_groups'), hostvars, host_name, strict)
+
+ def parse(self, inventory, loader, path, cache=True):
+
+ super(InventoryModule, self).parse(inventory, loader, path)
+
+ # read config from file, this sets 'options'
+ self._read_config_data(path)
+
+ # get connection host
+ self.foreman_url = self.get_option('url')
+ self.cache_key = self.get_cache_key(path)
+ self.use_cache = cache and self.get_option('cache')
+
+ # actually populate inventory
+ self._populate()
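+
+# Illustrative usage sketch (using the my.foreman.yml example above):
+#
+#     ansible-inventory -i my.foreman.yml --graph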
diff --git a/test/support/integration/plugins/lookup/rabbitmq.py b/test/support/integration/plugins/lookup/rabbitmq.py
new file mode 100644
index 00000000..7c2745f4
--- /dev/null
+++ b/test/support/integration/plugins/lookup/rabbitmq.py
@@ -0,0 +1,190 @@
+# (c) 2018, John Imison <john+github@imison.net>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+DOCUMENTATION = """
+ lookup: rabbitmq
+ author: John Imison <@Im0>
+ version_added: "2.8"
+ short_description: Retrieve messages from an AMQP/AMQPS RabbitMQ queue.
+ description:
+ - This lookup uses a basic get to retrieve all, or a limited number C(count) of, messages from a RabbitMQ queue.
+ options:
+ url:
+ description:
+ - A URI connection string to connect to the AMQP/AMQPS RabbitMQ server.
+ - For more information refer to the URI spec U(https://www.rabbitmq.com/uri-spec.html).
+ required: True
+ queue:
+ description:
+ - The queue to get messages from.
+ required: True
+ count:
+ description:
+ - How many messages to collect from the queue.
+ - If not set, defaults to retrieving all the messages from the queue.
+ requirements:
+ - The python pika package U(https://pypi.org/project/pika/).
+ notes:
+ - This lookup implements BlockingChannel.basic_get to get messages from a RabbitMQ server.
+ - After retrieving a message from the server, receipt of the message is acknowledged and the message on the server is deleted.
+ - Pika is a pure-Python implementation of the AMQP 0-9-1 protocol that tries to stay fairly independent of the underlying network support library.
+ - More information about pika can be found at U(https://pika.readthedocs.io/en/stable/).
+ - This plugin is tested against RabbitMQ. Other AMQP 0-9-1 protocol based servers may work, but are not tested or guaranteed.
+ - Assigning the return messages to a variable under C(vars) may result in unexpected results as the lookup is evaluated every time the
+ variable is referenced.
+ - Currently this plugin only handles text based messages from a queue. Unexpected results may occur when retrieving binary data.
+"""
+
+
+EXAMPLES = """
+- name: Get all messages off a queue
+ debug:
+ msg: "{{ lookup('rabbitmq', url='amqp://guest:guest@192.168.0.10:5672/%2F', queue='hello') }}"
+
+
+# If you intend to use the returned messages as a variable in more than
+# one task (e.g. debug, template), it is recommended to use set_fact.
+
+- name: Get 2 messages off a queue and set a fact for re-use
+ set_fact:
+ messages: "{{ lookup('rabbitmq', url='amqp://guest:guest@192.168.0.10:5672/%2F', queue='hello', count=2) }}"
+
+- name: Dump out contents of the messages
+ debug:
+ var: messages
+
+"""
+
+RETURN = """
+ _list:
+ description:
+ - A list of dictionaries with keys and values from the queue.
+ type: list
+ contains:
+ content_type:
+ description: The content_type on the message in the queue.
+ type: str
+ delivery_mode:
+ description: The delivery_mode on the message in the queue.
+ type: str
+ delivery_tag:
+ description: The delivery_tag on the message in the queue.
+ type: str
+ exchange:
+ description: The exchange the message came from.
+ type: str
+ message_count:
+ description: The message_count for the message on the queue.
+ type: str
+ msg:
+ description: The content of the message.
+ type: str
+ redelivered:
+ description: The redelivered flag. True if the message has been delivered before.
+ type: bool
+ routing_key:
+ description: The routing_key on the message in the queue.
+ type: str
+ headers:
+ description: The headers for the message returned from the queue.
+ type: dict
+ json:
+ description: If application/json is specified in content_type, json will be loaded into variables.
+ type: dict
+
+"""
+
+import json
+
+from ansible.errors import AnsibleError, AnsibleParserError
+from ansible.plugins.lookup import LookupBase
+from ansible.module_utils._text import to_native, to_text
+from ansible.utils.display import Display
+
+try:
+ import pika
+ from pika import spec
+ HAS_PIKA = True
+except ImportError:
+ HAS_PIKA = False
+
+display = Display()
+
+
+class LookupModule(LookupBase):
+
+ def run(self, terms, variables=None, url=None, queue=None, count=None):
+ if not HAS_PIKA:
+ raise AnsibleError('pika python package is required for rabbitmq lookup.')
+ if not url:
+ raise AnsibleError('URL is required for rabbitmq lookup.')
+ if not queue:
+ raise AnsibleError('Queue is required for rabbitmq lookup.')
+
+ display.vvv(u"terms:%s : variables:%s url:%s queue:%s count:%s" % (terms, variables, url, queue, count))
+
+ try:
+ parameters = pika.URLParameters(url)
+ except Exception as e:
+ raise AnsibleError("URL malformed: %s" % to_native(e))
+
+ try:
+ connection = pika.BlockingConnection(parameters)
+ except Exception as e:
+ raise AnsibleError("Connection issue: %s" % to_native(e))
+
+ try:
+ conn_channel = connection.channel()
+ except pika.exceptions.AMQPChannelError as e:
+ try:
+ connection.close()
+ except pika.exceptions.AMQPConnectionError as ie:
+ raise AnsibleError("Channel and connection closing issues: %s / %s" % to_native(e), to_native(ie))
+ raise AnsibleError("Channel issue: %s" % to_native(e))
+
+ ret = []
+ idx = 0
+
+ while True:
+ method_frame, properties, body = conn_channel.basic_get(queue=queue)
+ if method_frame:
+ display.vvv(u"%s, %s, %s " % (method_frame, properties, to_text(body)))
+
+ # TODO: In the future consider checking content_type and handle text/binary data differently.
+ msg_details = {
+ 'msg': to_text(body),
+ 'message_count': method_frame.message_count,
+ 'routing_key': method_frame.routing_key,
+ 'delivery_tag': method_frame.delivery_tag,
+ 'redelivered': method_frame.redelivered,
+ 'exchange': method_frame.exchange,
+ 'delivery_mode': properties.delivery_mode,
+ 'content_type': properties.content_type,
+ 'headers': properties.headers
+ }
+ if properties.content_type == 'application/json':
+ try:
+ msg_details['json'] = json.loads(msg_details['msg'])
+ except ValueError as e:
+ raise AnsibleError("Unable to decode JSON for message %s: %s" % (method_frame.delivery_tag, to_native(e)))
+
+ ret.append(msg_details)
+ conn_channel.basic_ack(method_frame.delivery_tag)
+ idx += 1
+ if method_frame.message_count == 0 or idx == count:
+ break
+ # If we didn't get a method_frame, exit.
+ else:
+ break
+
+ if connection.is_closed:
+ return [ret]
+ else:
+ try:
+ connection.close()
+ except pika.exceptions.AMQPConnectionError:
+ pass
+ return [ret]
diff --git a/test/support/integration/plugins/module_utils/aws/__init__.py b/test/support/integration/plugins/module_utils/aws/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/aws/__init__.py
diff --git a/test/support/integration/plugins/module_utils/aws/core.py b/test/support/integration/plugins/module_utils/aws/core.py
new file mode 100644
index 00000000..c4527b6d
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/aws/core.py
@@ -0,0 +1,335 @@
+#
+# Copyright 2017 Michael De La Rue | Ansible
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+"""This module adds shared support for generic Amazon AWS modules
+
+**This code is not yet ready for use in user modules. As of 2017**
+**and through to 2018, the interface is likely to change**
+**aggressively as the exact correct interface for ansible AWS modules**
+**is identified. In particular, until this notice goes away or is**
+**changed, methods may disappear from the interface. Please don't**
+**publish modules using this except directly to the main Ansible**
+**development repository.**
+
+In order to use this module, include it as part of a custom
+module as shown below.
+
+ from ansible.module_utils.aws import AnsibleAWSModule
+ module = AnsibleAWSModule(argument_spec=dictionary, supports_check_mode=boolean,
+ mutually_exclusive=list1, required_together=list2)
+
+The 'AnsibleAWSModule' class provides similar, but more restricted,
+interfaces to the normal Ansible module. It also includes additional
+methods for connecting to AWS using the standard module arguments
+
+ m.resource('lambda') # - get an AWS connection as a boto3 resource.
+
+or
+
+ m.client('sts') # - get an AWS connection as a boto3 client.
+
+To make AWSRetry easier to use, it can now be wrapped around any call from a
+module-created client. To add retries to a client, create it with a retry decorator:
+
+ m.client('ec2', retry_decorator=AWSRetry.jittered_backoff(retries=10))
+
+Any calls from that client can be made to use the decorator passed at call-time
+using the `aws_retry` argument. By default, no retries are used.
+
+ ec2 = m.client('ec2', retry_decorator=AWSRetry.jittered_backoff(retries=10))
+ ec2.describe_instances(InstanceIds=['i-123456789'], aws_retry=True)
+
+The call will be retried the specified number of times, so the calling functions
+don't need to be wrapped in the backoff decorator.
+"""
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import re
+import logging
+import traceback
+from functools import wraps
+from distutils.version import LooseVersion
+
+try:
+ from cStringIO import StringIO
+except ImportError:
+ # Python 3
+ from io import StringIO
+
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils._text import to_native
+from ansible.module_utils.ec2 import HAS_BOTO3, camel_dict_to_snake_dict, ec2_argument_spec, boto3_conn
+from ansible.module_utils.ec2 import get_aws_connection_info, get_aws_region
+
+# We will also export HAS_BOTO3 so end user modules can use it.
+__all__ = ('AnsibleAWSModule', 'HAS_BOTO3', 'is_boto3_error_code')
+
+
+class AnsibleAWSModule(object):
+ """An ansible module class for AWS modules
+
+ AnsibleAWSModule provides a base class for building modules which
+ connect to Amazon Web Services. The interface is currently more
+ restricted than the basic module class with the aim that later the
+ basic module class can be reduced. If you find that any key
+ feature is missing please contact the author/Ansible AWS team
+ (available on #ansible-aws on IRC) to request the additional
+ features needed.
+ """
+ default_settings = {
+ "default_args": True,
+ "check_boto3": True,
+ "auto_retry": True,
+ "module_class": AnsibleModule
+ }
+
+ def __init__(self, **kwargs):
+ local_settings = {}
+ for key in AnsibleAWSModule.default_settings:
+ try:
+ local_settings[key] = kwargs.pop(key)
+ except KeyError:
+ local_settings[key] = AnsibleAWSModule.default_settings[key]
+ self.settings = local_settings
+
+ if local_settings["default_args"]:
+ # ec2_argument_spec contains the region so we use that; there's a patch coming which
+ # will add it to aws_argument_spec so if that's accepted then later we should change
+ # over
+ argument_spec_full = ec2_argument_spec()
+ try:
+ argument_spec_full.update(kwargs["argument_spec"])
+ except (TypeError, NameError):
+ pass
+ kwargs["argument_spec"] = argument_spec_full
+
+ self._module = AnsibleAWSModule.default_settings["module_class"](**kwargs)
+
+ if local_settings["check_boto3"] and not HAS_BOTO3:
+ self._module.fail_json(
+ msg=missing_required_lib('botocore or boto3'))
+
+ self.check_mode = self._module.check_mode
+ self._diff = self._module._diff
+ self._name = self._module._name
+
+ self._botocore_endpoint_log_stream = StringIO()
+ self.logger = None
+ if self.params.get('debug_botocore_endpoint_logs'):
+ self.logger = logging.getLogger('botocore.endpoint')
+ self.logger.setLevel(logging.DEBUG)
+ self.logger.addHandler(logging.StreamHandler(self._botocore_endpoint_log_stream))
+
+ @property
+ def params(self):
+ return self._module.params
+
+ def _get_resource_action_list(self):
+ actions = []
+ for ln in self._botocore_endpoint_log_stream.getvalue().split('\n'):
+ ln = ln.strip()
+ if not ln:
+ continue
+ found_operational_request = re.search(r"OperationModel\(name=.*?\)", ln)
+ if found_operational_request:
+ operation_request = found_operational_request.group(0)[20:-1]
+ resource = re.search(r"https://.*?\.", ln).group(0)[8:-1]
+ actions.append("{0}:{1}".format(resource, operation_request))
+ return list(set(actions))
+
+ def exit_json(self, *args, **kwargs):
+ if self.params.get('debug_botocore_endpoint_logs'):
+ kwargs['resource_actions'] = self._get_resource_action_list()
+ return self._module.exit_json(*args, **kwargs)
+
+ def fail_json(self, *args, **kwargs):
+ if self.params.get('debug_botocore_endpoint_logs'):
+ kwargs['resource_actions'] = self._get_resource_action_list()
+ return self._module.fail_json(*args, **kwargs)
+
+ def debug(self, *args, **kwargs):
+ return self._module.debug(*args, **kwargs)
+
+ def warn(self, *args, **kwargs):
+ return self._module.warn(*args, **kwargs)
+
+ def deprecate(self, *args, **kwargs):
+ return self._module.deprecate(*args, **kwargs)
+
+ def boolean(self, *args, **kwargs):
+ return self._module.boolean(*args, **kwargs)
+
+ def md5(self, *args, **kwargs):
+ return self._module.md5(*args, **kwargs)
+
+ def client(self, service, retry_decorator=None):
+ region, ec2_url, aws_connect_kwargs = get_aws_connection_info(self, boto3=True)
+ conn = boto3_conn(self, conn_type='client', resource=service,
+ region=region, endpoint=ec2_url, **aws_connect_kwargs)
+ return conn if retry_decorator is None else _RetryingBotoClientWrapper(conn, retry_decorator)
+
+ def resource(self, service):
+ region, ec2_url, aws_connect_kwargs = get_aws_connection_info(self, boto3=True)
+ return boto3_conn(self, conn_type='resource', resource=service,
+ region=region, endpoint=ec2_url, **aws_connect_kwargs)
+
+ @property
+ def region(self, boto3=True):
+ return get_aws_region(self, boto3)
+
+ def fail_json_aws(self, exception, msg=None):
+ """call fail_json with processed exception
+
+ function for converting exceptions thrown by AWS SDK modules,
+ botocore, boto3 and boto, into nice error messages.
+ """
+ last_traceback = traceback.format_exc()
+
+ # to_native is trusted to handle exceptions that str() could
+ # convert to text.
+ try:
+ except_msg = to_native(exception.message)
+ except AttributeError:
+ except_msg = to_native(exception)
+
+ if msg is not None:
+ message = '{0}: {1}'.format(msg, except_msg)
+ else:
+ message = except_msg
+
+ try:
+ response = exception.response
+ except AttributeError:
+ response = None
+
+ failure = dict(
+ msg=message,
+ exception=last_traceback,
+ **self._gather_versions()
+ )
+
+ if response is not None:
+ failure.update(**camel_dict_to_snake_dict(response))
+
+ self.fail_json(**failure)
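+
+ # Illustrative usage sketch (hypothetical module code, not part of this
+ # class): catch botocore errors and hand them to fail_json_aws:
+ #
+ # try:
+ # ec2.describe_instances(InstanceIds=['i-123456789'])
+ # except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ # module.fail_json_aws(e, msg="Failed to describe instances")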
+
+ def _gather_versions(self):
+ """Gather AWS SDK (boto3 and botocore) dependency versions
+
+ Returns {'boto3_version': str, 'botocore_version': str}
+ Returns {} if neither are installed
+ """
+ if not HAS_BOTO3:
+ return {}
+ import boto3
+ import botocore
+ return dict(boto3_version=boto3.__version__,
+ botocore_version=botocore.__version__)
+
+ def boto3_at_least(self, desired):
+ """Check if the available boto3 version is greater than or equal to a desired version.
+
+ Usage:
+ if module.params.get('assign_ipv6_address') and not module.boto3_at_least('1.4.4'):
+ # conditionally fail on old boto3 versions if a specific feature is not supported
+ module.fail_json(msg="Boto3 can't deal with EC2 IPv6 addresses before version 1.4.4.")
+ """
+ existing = self._gather_versions()
+ return LooseVersion(existing['boto3_version']) >= LooseVersion(desired)
+
+ def botocore_at_least(self, desired):
+ """Check if the available botocore version is greater than or equal to a desired version.
+
+ Usage:
+ if not module.botocore_at_least('1.2.3'):
+ module.fail_json(msg='The Serverless Elastic Load Compute Service is not in botocore before v1.2.3')
+ if not module.botocore_at_least('1.5.3'):
+ module.warn('Botocore did not include waiters for Service X before 1.5.3. '
+ 'To wait until Service X resources are fully available, update botocore.')
+ """
+ existing = self._gather_versions()
+ return LooseVersion(existing['botocore_version']) >= LooseVersion(desired)
+
+
+class _RetryingBotoClientWrapper(object):
+ __never_wait = (
+ 'get_paginator', 'can_paginate',
+ 'get_waiter', 'generate_presigned_url',
+ )
+
+ def __init__(self, client, retry):
+ self.client = client
+ self.retry = retry
+
+ def _create_optional_retry_wrapper_function(self, unwrapped):
+ retrying_wrapper = self.retry(unwrapped)
+
+ @wraps(unwrapped)
+ def deciding_wrapper(aws_retry=False, *args, **kwargs):
+ if aws_retry:
+ return retrying_wrapper(*args, **kwargs)
+ else:
+ return unwrapped(*args, **kwargs)
+ return deciding_wrapper
+
+ def __getattr__(self, name):
+ unwrapped = getattr(self.client, name)
+ if name in self.__never_wait:
+ return unwrapped
+ elif callable(unwrapped):
+ wrapped = self._create_optional_retry_wrapper_function(unwrapped)
+ setattr(self, name, wrapped)
+ return wrapped
+ else:
+ return unwrapped
+
+
+def is_boto3_error_code(code, e=None):
+ """Check if the botocore exception is raised by a specific error code.
+
+ Returns ClientError if the error code matches, a dummy exception if it does not have an error code or does not match
+
+ Example:
+ try:
+ ec2.describe_instances(InstanceIds=['potato'])
+ except is_boto3_error_code('InvalidInstanceID.Malformed'):
+ # handle the error for that code case
+ except botocore.exceptions.ClientError as e:
+ # handle the generic error case for all other codes
+ """
+ from botocore.exceptions import ClientError
+ if e is None:
+ import sys
+ dummy, e, dummy = sys.exc_info()
+ if isinstance(e, ClientError) and e.response['Error']['Code'] == code:
+ return ClientError
+ return type('NeverEverRaisedException', (Exception,), {})
+
+
+def get_boto3_client_method_parameters(client, method_name, required=False):
+ op = client.meta.method_to_api_mapping.get(method_name)
+ input_shape = client._service_model.operation_model(op).input_shape
+ if not input_shape:
+ parameters = []
+ elif required:
+ parameters = list(input_shape.required_members)
+ else:
+ parameters = list(input_shape.members.keys())
+ return parameters
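+
+# Illustrative usage sketch (assumes a client built via module.client('s3')):
+#
+#     get_boto3_client_method_parameters(s3, 'get_object', required=True)
+#     # -> e.g. ['Bucket', 'Key'], per the service model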
diff --git a/test/support/integration/plugins/module_utils/aws/iam.py b/test/support/integration/plugins/module_utils/aws/iam.py
new file mode 100644
index 00000000..f05999aa
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/aws/iam.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import re
+import traceback
+
+try:
+ from botocore.exceptions import ClientError, NoCredentialsError
+except ImportError:
+ pass # caught by HAS_BOTO3
+
+from ansible.module_utils._text import to_native
+
+
+def get_aws_account_id(module):
+ """ Given AnsibleAWSModule instance, get the active AWS account ID
+
+ get_aws_account_id tries to find out the account that we are working
+ on. It's not guaranteed that this will be easy so we try in
+ several different ways. Giving either IAM or STS privileges to
+ the account should be enough to permit this.
+ """
+ account_id = None
+ try:
+ sts_client = module.client('sts')
+ account_id = sts_client.get_caller_identity().get('Account')
+ # non-STS sessions may also get NoCredentialsError from this STS call, so
+ # we must catch that too and try the IAM version
+ except (ClientError, NoCredentialsError):
+ try:
+ iam_client = module.client('iam')
+ account_id = iam_client.get_user()['User']['Arn'].split(':')[4]
+ except ClientError as e:
+ if (e.response['Error']['Code'] == 'AccessDenied'):
+ except_msg = to_native(e)
+ # don't match on `arn:aws` because of China region `arn:aws-cn` and similar
+ account_id = re.search(r"arn:\w+:iam::([0-9]{12,32}):\w+/", except_msg).group(1)
+ if account_id is None:
+ module.fail_json_aws(e, msg="Could not get AWS account information")
+ except Exception as e:
+ module.fail_json(
+ msg="Failed to get AWS account information, Try allowing sts:GetCallerIdentity or iam:GetUser permissions.",
+ exception=traceback.format_exc()
+ )
+ if not account_id:
+ module.fail_json(msg="Failed while determining AWS account ID. Try allowing sts:GetCallerIdentity or iam:GetUser permissions.")
+ return to_native(account_id)
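+
+# Illustrative usage sketch (hypothetical module code; MyRole is a made-up name):
+#
+#     account_id = get_aws_account_id(module)
+#     arn = 'arn:aws:iam::{0}:role/MyRole'.format(account_id)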
diff --git a/test/support/integration/plugins/module_utils/aws/s3.py b/test/support/integration/plugins/module_utils/aws/s3.py
new file mode 100644
index 00000000..2185869d
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/aws/s3.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2018 Red Hat, Inc.
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+try:
+ from botocore.exceptions import BotoCoreError, ClientError
+except ImportError:
+ pass # Handled by the calling module
+
+HAS_MD5 = True
+try:
+ from hashlib import md5
+except ImportError:
+ try:
+ from md5 import md5
+ except ImportError:
+ HAS_MD5 = False
+
+
+def calculate_etag(module, filename, etag, s3, bucket, obj, version=None):
+ if not HAS_MD5:
+ return None
+
+ if '-' in etag:
+ # Multi-part ETag; a hash of the hashes of each part.
+ parts = int(etag[1:-1].split('-')[1])
+ digests = []
+
+ s3_kwargs = dict(
+ Bucket=bucket,
+ Key=obj,
+ )
+ if version:
+ s3_kwargs['VersionId'] = version
+
+ with open(filename, 'rb') as f:
+ for part_num in range(1, parts + 1):
+ s3_kwargs['PartNumber'] = part_num
+ try:
+ head = s3.head_object(**s3_kwargs)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to get head object")
+ digests.append(md5(f.read(int(head['ContentLength']))))
+
+ digest_squared = md5(b''.join(m.digest() for m in digests))
+ return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests))
+ else: # Compute the MD5 sum normally
+ return '"{0}"'.format(module.md5(filename))
diff --git a/test/support/integration/plugins/module_utils/aws/waiters.py b/test/support/integration/plugins/module_utils/aws/waiters.py
new file mode 100644
index 00000000..25db598b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/aws/waiters.py
@@ -0,0 +1,405 @@
+# Copyright: (c) 2018, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+try:
+ import botocore.waiter as core_waiter
+except ImportError:
+ pass # caught by HAS_BOTO3
+
+
+ec2_data = {
+ "version": 2,
+ "waiters": {
+ "InternetGatewayExists": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeInternetGateways",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "length(InternetGateways) > `0`",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "expected": "InvalidInternetGatewayID.NotFound",
+ "state": "retry"
+ },
+ ]
+ },
+ "RouteTableExists": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeRouteTables",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "length(RouteTables[]) > `0`",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "expected": "InvalidRouteTableID.NotFound",
+ "state": "retry"
+ },
+ ]
+ },
+ "SecurityGroupExists": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeSecurityGroups",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "length(SecurityGroups[]) > `0`",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "expected": "InvalidGroup.NotFound",
+ "state": "retry"
+ },
+ ]
+ },
+ "SubnetExists": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeSubnets",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "length(Subnets[]) > `0`",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "expected": "InvalidSubnetID.NotFound",
+ "state": "retry"
+ },
+ ]
+ },
+ "SubnetHasMapPublic": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeSubnets",
+ "acceptors": [
+ {
+ "matcher": "pathAll",
+ "expected": True,
+ "argument": "Subnets[].MapPublicIpOnLaunch",
+ "state": "success"
+ },
+ ]
+ },
+ "SubnetNoMapPublic": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeSubnets",
+ "acceptors": [
+ {
+ "matcher": "pathAll",
+ "expected": False,
+ "argument": "Subnets[].MapPublicIpOnLaunch",
+ "state": "success"
+ },
+ ]
+ },
+ "SubnetHasAssignIpv6": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeSubnets",
+ "acceptors": [
+ {
+ "matcher": "pathAll",
+ "expected": True,
+ "argument": "Subnets[].AssignIpv6AddressOnCreation",
+ "state": "success"
+ },
+ ]
+ },
+ "SubnetNoAssignIpv6": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeSubnets",
+ "acceptors": [
+ {
+ "matcher": "pathAll",
+ "expected": False,
+ "argument": "Subnets[].AssignIpv6AddressOnCreation",
+ "state": "success"
+ },
+ ]
+ },
+ "SubnetDeleted": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeSubnets",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "length(Subnets[]) > `0`",
+ "state": "retry"
+ },
+ {
+ "matcher": "error",
+ "expected": "InvalidSubnetID.NotFound",
+ "state": "success"
+ },
+ ]
+ },
+ "VpnGatewayExists": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeVpnGateways",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "length(VpnGateways[]) > `0`",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "expected": "InvalidVpnGatewayID.NotFound",
+ "state": "retry"
+ },
+ ]
+ },
+ "VpnGatewayDetached": {
+ "delay": 5,
+ "maxAttempts": 40,
+ "operation": "DescribeVpnGateways",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "VpnGateways[0].State == 'available'",
+ "state": "success"
+ },
+ ]
+ },
+ }
+}
+
+
+waf_data = {
+ "version": 2,
+ "waiters": {
+ "ChangeTokenInSync": {
+ "delay": 20,
+ "maxAttempts": 60,
+ "operation": "GetChangeTokenStatus",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "ChangeTokenStatus == 'INSYNC'",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "expected": "WAFInternalErrorException",
+ "state": "retry"
+ }
+ ]
+ }
+ }
+}
+
+eks_data = {
+ "version": 2,
+ "waiters": {
+ "ClusterActive": {
+ "delay": 20,
+ "maxAttempts": 60,
+ "operation": "DescribeCluster",
+ "acceptors": [
+ {
+ "state": "success",
+ "matcher": "path",
+ "argument": "cluster.status",
+ "expected": "ACTIVE"
+ },
+ {
+ "state": "retry",
+ "matcher": "error",
+ "expected": "ResourceNotFoundException"
+ }
+ ]
+ },
+ "ClusterDeleted": {
+ "delay": 20,
+ "maxAttempts": 60,
+ "operation": "DescribeCluster",
+ "acceptors": [
+ {
+ "state": "retry",
+ "matcher": "path",
+ "argument": "cluster.status != 'DELETED'",
+ "expected": True
+ },
+ {
+ "state": "success",
+ "matcher": "error",
+ "expected": "ResourceNotFoundException"
+ }
+ ]
+ }
+ }
+}
+
+
+rds_data = {
+ "version": 2,
+ "waiters": {
+ "DBInstanceStopped": {
+ "delay": 20,
+ "maxAttempts": 60,
+ "operation": "DescribeDBInstances",
+ "acceptors": [
+ {
+ "state": "success",
+ "matcher": "pathAll",
+ "argument": "DBInstances[].DBInstanceStatus",
+ "expected": "stopped"
+ },
+ ]
+ }
+ }
+}
+
+
+def ec2_model(name):
+ ec2_models = core_waiter.WaiterModel(waiter_config=ec2_data)
+ return ec2_models.get_waiter(name)
+
+
+def waf_model(name):
+ waf_models = core_waiter.WaiterModel(waiter_config=waf_data)
+ return waf_models.get_waiter(name)
+
+
+def eks_model(name):
+ eks_models = core_waiter.WaiterModel(waiter_config=eks_data)
+ return eks_models.get_waiter(name)
+
+
+def rds_model(name):
+ rds_models = core_waiter.WaiterModel(waiter_config=rds_data)
+ return rds_models.get_waiter(name)
+
+
+waiters_by_name = {
+ ('EC2', 'internet_gateway_exists'): lambda ec2: core_waiter.Waiter(
+ 'internet_gateway_exists',
+ ec2_model('InternetGatewayExists'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_internet_gateways
+ )),
+ ('EC2', 'route_table_exists'): lambda ec2: core_waiter.Waiter(
+ 'route_table_exists',
+ ec2_model('RouteTableExists'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_route_tables
+ )),
+ ('EC2', 'security_group_exists'): lambda ec2: core_waiter.Waiter(
+ 'security_group_exists',
+ ec2_model('SecurityGroupExists'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_security_groups
+ )),
+ ('EC2', 'subnet_exists'): lambda ec2: core_waiter.Waiter(
+ 'subnet_exists',
+ ec2_model('SubnetExists'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_subnets
+ )),
+ ('EC2', 'subnet_has_map_public'): lambda ec2: core_waiter.Waiter(
+ 'subnet_has_map_public',
+ ec2_model('SubnetHasMapPublic'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_subnets
+ )),
+ ('EC2', 'subnet_no_map_public'): lambda ec2: core_waiter.Waiter(
+ 'subnet_no_map_public',
+ ec2_model('SubnetNoMapPublic'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_subnets
+ )),
+ ('EC2', 'subnet_has_assign_ipv6'): lambda ec2: core_waiter.Waiter(
+ 'subnet_has_assign_ipv6',
+ ec2_model('SubnetHasAssignIpv6'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_subnets
+ )),
+ ('EC2', 'subnet_no_assign_ipv6'): lambda ec2: core_waiter.Waiter(
+ 'subnet_no_assign_ipv6',
+ ec2_model('SubnetNoAssignIpv6'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_subnets
+ )),
+ ('EC2', 'subnet_deleted'): lambda ec2: core_waiter.Waiter(
+ 'subnet_deleted',
+ ec2_model('SubnetDeleted'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_subnets
+ )),
+ ('EC2', 'vpn_gateway_exists'): lambda ec2: core_waiter.Waiter(
+ 'vpn_gateway_exists',
+ ec2_model('VpnGatewayExists'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_vpn_gateways
+ )),
+ ('EC2', 'vpn_gateway_detached'): lambda ec2: core_waiter.Waiter(
+ 'vpn_gateway_detached',
+ ec2_model('VpnGatewayDetached'),
+ core_waiter.NormalizedOperationMethod(
+ ec2.describe_vpn_gateways
+ )),
+ ('WAF', 'change_token_in_sync'): lambda waf: core_waiter.Waiter(
+ 'change_token_in_sync',
+ waf_model('ChangeTokenInSync'),
+ core_waiter.NormalizedOperationMethod(
+ waf.get_change_token_status
+ )),
+ ('WAFRegional', 'change_token_in_sync'): lambda waf: core_waiter.Waiter(
+ 'change_token_in_sync',
+ waf_model('ChangeTokenInSync'),
+ core_waiter.NormalizedOperationMethod(
+ waf.get_change_token_status
+ )),
+ ('EKS', 'cluster_active'): lambda eks: core_waiter.Waiter(
+ 'cluster_active',
+ eks_model('ClusterActive'),
+ core_waiter.NormalizedOperationMethod(
+ eks.describe_cluster
+ )),
+ ('EKS', 'cluster_deleted'): lambda eks: core_waiter.Waiter(
+ 'cluster_deleted',
+ eks_model('ClusterDeleted'),
+ core_waiter.NormalizedOperationMethod(
+ eks.describe_cluster
+ )),
+ ('RDS', 'db_instance_stopped'): lambda rds: core_waiter.Waiter(
+ 'db_instance_stopped',
+ rds_model('DBInstanceStopped'),
+ core_waiter.NormalizedOperationMethod(
+ rds.describe_db_instances
+ )),
+}
+
+
+def get_waiter(client, waiter_name):
+ try:
+ return waiters_by_name[(client.__class__.__name__, waiter_name)](client)
+ except KeyError:
+ raise NotImplementedError("Waiter {0} could not be found for client {1}. Available waiters: {2}".format(
+ waiter_name, type(client), ', '.join(repr(k) for k in waiters_by_name.keys())))
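+
+# Illustrative usage sketch (hypothetical subnet ID; assumes an EC2 client
+# created via module.client('ec2')):
+#
+#     waiter = get_waiter(ec2, 'subnet_exists')
+#     waiter.wait(SubnetIds=['subnet-0123456789abcdef0'])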
diff --git a/test/support/integration/plugins/module_utils/azure_rm_common.py b/test/support/integration/plugins/module_utils/azure_rm_common.py
new file mode 100644
index 00000000..a7b55e97
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/azure_rm_common.py
@@ -0,0 +1,1473 @@
+# Copyright (c) 2016 Matt Davis, <mdavis@ansible.com>
+# Chris Houseknecht, <house@redhat.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+import os
+import re
+import types
+import copy
+import inspect
+import traceback
+import json
+
+from os.path import expanduser
+
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+try:
+ from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION
+except Exception:
+ ANSIBLE_VERSION = 'unknown'
+from ansible.module_utils.six.moves import configparser
+import ansible.module_utils.six.moves.urllib.parse as urlparse
+
+AZURE_COMMON_ARGS = dict(
+ auth_source=dict(
+ type='str',
+ choices=['auto', 'cli', 'env', 'credential_file', 'msi']
+ ),
+ profile=dict(type='str'),
+ subscription_id=dict(type='str'),
+ client_id=dict(type='str', no_log=True),
+ secret=dict(type='str', no_log=True),
+ tenant=dict(type='str', no_log=True),
+ ad_user=dict(type='str', no_log=True),
+ password=dict(type='str', no_log=True),
+ cloud_environment=dict(type='str', default='AzureCloud'),
+ cert_validation_mode=dict(type='str', choices=['validate', 'ignore']),
+ api_profile=dict(type='str', default='latest'),
+ adfs_authority_url=dict(type='str', default=None)
+)
+
+AZURE_CREDENTIAL_ENV_MAPPING = dict(
+ profile='AZURE_PROFILE',
+ subscription_id='AZURE_SUBSCRIPTION_ID',
+ client_id='AZURE_CLIENT_ID',
+ secret='AZURE_SECRET',
+ tenant='AZURE_TENANT',
+ ad_user='AZURE_AD_USER',
+ password='AZURE_PASSWORD',
+ cloud_environment='AZURE_CLOUD_ENVIRONMENT',
+ cert_validation_mode='AZURE_CERT_VALIDATION_MODE',
+ adfs_authority_url='AZURE_ADFS_AUTHORITY_URL'
+)
+
+
+class SDKProfile(object): # pylint: disable=too-few-public-methods
+
+ def __init__(self, default_api_version, profile=None):
+ """Constructor.
+
+ :param str default_api_version: Default API version if not overridden by a profile. Nullable.
+ :param profile: A dict mapping operation group name to API version.
+ :type profile: dict[str, str]
+ """
+ self.profile = profile if profile is not None else {}
+ self.profile[None] = default_api_version
+
+ @property
+ def default_api_version(self):
+ return self.profile[None]
+
+
+# FUTURE: this should come from the SDK or an external location.
+# For now, we have to copy from azure-cli
+AZURE_API_PROFILES = {
+ 'latest': {
+ 'ContainerInstanceManagementClient': '2018-02-01-preview',
+ 'ComputeManagementClient': dict(
+ default_api_version='2018-10-01',
+ resource_skus='2018-10-01',
+ disks='2018-06-01',
+ snapshots='2018-10-01',
+ virtual_machine_run_commands='2018-10-01'
+ ),
+ 'NetworkManagementClient': '2018-08-01',
+ 'ResourceManagementClient': '2017-05-10',
+ 'StorageManagementClient': '2017-10-01',
+ 'WebSiteManagementClient': '2018-02-01',
+ 'PostgreSQLManagementClient': '2017-12-01',
+ 'MySQLManagementClient': '2017-12-01',
+ 'MariaDBManagementClient': '2019-03-01',
+ 'ManagementLockClient': '2016-09-01'
+ },
+ '2019-03-01-hybrid': {
+ 'StorageManagementClient': '2017-10-01',
+ 'NetworkManagementClient': '2017-10-01',
+ 'ComputeManagementClient': SDKProfile('2017-12-01', {
+ 'resource_skus': '2017-09-01',
+ 'disks': '2017-03-30',
+ 'snapshots': '2017-03-30'
+ }),
+ 'ManagementLinkClient': '2016-09-01',
+ 'ManagementLockClient': '2016-09-01',
+ 'PolicyClient': '2016-12-01',
+ 'ResourceManagementClient': '2018-05-01',
+ 'SubscriptionClient': '2016-06-01',
+ 'DnsManagementClient': '2016-04-01',
+ 'KeyVaultManagementClient': '2016-10-01',
+ 'AuthorizationManagementClient': SDKProfile('2015-07-01', {
+ 'classic_administrators': '2015-06-01',
+ 'policy_assignments': '2016-12-01',
+ 'policy_definitions': '2016-12-01'
+ }),
+ 'KeyVaultClient': '2016-10-01',
+ 'azure.multiapi.storage': '2017-11-09',
+ 'azure.multiapi.cosmosdb': '2017-04-17'
+ },
+ '2018-03-01-hybrid': {
+ 'StorageManagementClient': '2016-01-01',
+ 'NetworkManagementClient': '2017-10-01',
+ 'ComputeManagementClient': SDKProfile('2017-03-30'),
+ 'ManagementLinkClient': '2016-09-01',
+ 'ManagementLockClient': '2016-09-01',
+ 'PolicyClient': '2016-12-01',
+ 'ResourceManagementClient': '2018-02-01',
+ 'SubscriptionClient': '2016-06-01',
+ 'DnsManagementClient': '2016-04-01',
+ 'KeyVaultManagementClient': '2016-10-01',
+ 'AuthorizationManagementClient': SDKProfile('2015-07-01', {
+ 'classic_administrators': '2015-06-01'
+ }),
+ 'KeyVaultClient': '2016-10-01',
+ 'azure.multiapi.storage': '2017-04-17',
+ 'azure.multiapi.cosmosdb': '2017-04-17'
+ },
+ '2017-03-09-profile': {
+ 'StorageManagementClient': '2016-01-01',
+ 'NetworkManagementClient': '2015-06-15',
+ 'ComputeManagementClient': SDKProfile('2016-03-30'),
+ 'ManagementLinkClient': '2016-09-01',
+ 'ManagementLockClient': '2015-01-01',
+ 'PolicyClient': '2015-10-01-preview',
+ 'ResourceManagementClient': '2016-02-01',
+ 'SubscriptionClient': '2016-06-01',
+ 'DnsManagementClient': '2016-04-01',
+ 'KeyVaultManagementClient': '2016-10-01',
+ 'AuthorizationManagementClient': SDKProfile('2015-07-01', {
+ 'classic_administrators': '2015-06-01'
+ }),
+ 'KeyVaultClient': '2016-10-01',
+ 'azure.multiapi.storage': '2015-04-05'
+ }
+}
+
+AZURE_TAG_ARGS = dict(
+ tags=dict(type='dict'),
+ append_tags=dict(type='bool', default=True),
+)
+
+AZURE_COMMON_REQUIRED_IF = [
+ ('log_mode', 'file', ['log_path'])
+]
+
+ANSIBLE_USER_AGENT = 'Ansible/{0}'.format(ANSIBLE_VERSION)
+CLOUDSHELL_USER_AGENT_KEY = 'AZURE_HTTP_USER_AGENT'
+VSCODEEXT_USER_AGENT_KEY = 'VSCODEEXT_USER_AGENT'
+
+CIDR_PATTERN = re.compile(r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1"
+ r"[0-9]{2}|2[0-4][0-9]|25[0-5])(/([0-9]|[1-2][0-9]|3[0-2]))")
+
+AZURE_SUCCESS_STATE = "Succeeded"
+AZURE_FAILED_STATE = "Failed"
+
+HAS_AZURE = True
+HAS_AZURE_EXC = None
+HAS_AZURE_CLI_CORE = True
+HAS_AZURE_CLI_CORE_EXC = None
+
+HAS_MSRESTAZURE = True
+HAS_MSRESTAZURE_EXC = None
+
+try:
+ import importlib
+except ImportError:
+    # This passes the sanity import test, but does not provide a user-friendly error message.
+    # Doing so would require catching Exception for all imports of Azure dependencies in modules and module_utils.
+ importlib = None
+
+try:
+ from packaging.version import Version
+ HAS_PACKAGING_VERSION = True
+ HAS_PACKAGING_VERSION_EXC = None
+except ImportError:
+ Version = None
+ HAS_PACKAGING_VERSION = False
+ HAS_PACKAGING_VERSION_EXC = traceback.format_exc()
+
+# NB: packaging issues sometimes cause msrestazure not to be installed; check it separately
+try:
+ from msrest.serialization import Serializer
+except ImportError:
+ HAS_MSRESTAZURE_EXC = traceback.format_exc()
+ HAS_MSRESTAZURE = False
+
+try:
+ from enum import Enum
+ from msrestazure.azure_active_directory import AADTokenCredentials
+ from msrestazure.azure_exceptions import CloudError
+ from msrestazure.azure_active_directory import MSIAuthentication
+ from msrestazure.tools import parse_resource_id, resource_id, is_valid_resource_id
+ from msrestazure import azure_cloud
+ from azure.common.credentials import ServicePrincipalCredentials, UserPassCredentials
+ from azure.mgmt.monitor.version import VERSION as monitor_client_version
+ from azure.mgmt.network.version import VERSION as network_client_version
+ from azure.mgmt.storage.version import VERSION as storage_client_version
+ from azure.mgmt.compute.version import VERSION as compute_client_version
+ from azure.mgmt.resource.version import VERSION as resource_client_version
+ from azure.mgmt.dns.version import VERSION as dns_client_version
+ from azure.mgmt.web.version import VERSION as web_client_version
+ from azure.mgmt.network import NetworkManagementClient
+ from azure.mgmt.resource.resources import ResourceManagementClient
+ from azure.mgmt.resource.subscriptions import SubscriptionClient
+ from azure.mgmt.storage import StorageManagementClient
+ from azure.mgmt.compute import ComputeManagementClient
+ from azure.mgmt.dns import DnsManagementClient
+ from azure.mgmt.monitor import MonitorManagementClient
+ from azure.mgmt.web import WebSiteManagementClient
+ from azure.mgmt.containerservice import ContainerServiceClient
+ from azure.mgmt.marketplaceordering import MarketplaceOrderingAgreements
+ from azure.mgmt.trafficmanager import TrafficManagerManagementClient
+ from azure.storage.cloudstorageaccount import CloudStorageAccount
+ from azure.storage.blob import PageBlobService, BlockBlobService
+ from adal.authentication_context import AuthenticationContext
+ from azure.mgmt.sql import SqlManagementClient
+ from azure.mgmt.servicebus import ServiceBusManagementClient
+ import azure.mgmt.servicebus.models as ServicebusModel
+ from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient
+ from azure.mgmt.rdbms.mysql import MySQLManagementClient
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from azure.mgmt.containerregistry import ContainerRegistryManagementClient
+ from azure.mgmt.containerinstance import ContainerInstanceManagementClient
+ from azure.mgmt.loganalytics import LogAnalyticsManagementClient
+ import azure.mgmt.loganalytics.models as LogAnalyticsModels
+ from azure.mgmt.automation import AutomationClient
+ import azure.mgmt.automation.models as AutomationModel
+ from azure.mgmt.iothub import IotHubClient
+ from azure.mgmt.iothub import models as IoTHubModels
+ from msrest.service_client import ServiceClient
+ from msrestazure import AzureConfiguration
+ from msrest.authentication import Authentication
+ from azure.mgmt.resource.locks import ManagementLockClient
+except ImportError as exc:
+ Authentication = object
+ HAS_AZURE_EXC = traceback.format_exc()
+ HAS_AZURE = False
+
+from base64 import b64encode, b64decode
+from hashlib import sha256
+from hmac import HMAC
+from time import time
+
+try:
+ from urllib import (urlencode, quote_plus)
+except ImportError:
+ from urllib.parse import (urlencode, quote_plus)
+
+try:
+ from azure.cli.core.util import CLIError
+ from azure.common.credentials import get_azure_cli_credentials, get_cli_profile
+ from azure.common.cloud import get_cli_active_cloud
+except ImportError:
+ HAS_AZURE_CLI_CORE = False
+ HAS_AZURE_CLI_CORE_EXC = None
+ CLIError = Exception
+
+
+def azure_id_to_dict(id):
+ pieces = re.sub(r'^\/', '', id).split('/')
+ result = {}
+ index = 0
+ while index < len(pieces) - 1:
+ result[pieces[index]] = pieces[index + 1]
+ index += 1
+ return result
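+
+# Example (the pairing is a sliding window over the path segments, so
+# intermediate values also appear as keys):
+#   azure_id_to_dict('/subscriptions/abc/resourceGroups/myrg')
+#   => {'subscriptions': 'abc', 'abc': 'resourceGroups', 'resourceGroups': 'myrg'}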
+
+
+def format_resource_id(val, subscription_id, namespace, types, resource_group):
+ return resource_id(name=val,
+ resource_group=resource_group,
+ namespace=namespace,
+ type=types,
+ subscription=subscription_id) if not is_valid_resource_id(val) else val
+
+
+def normalize_location_name(name):
+ return name.replace(' ', '').lower()
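+
+# Example: normalize_location_name('East US 2') => 'eastus2'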
+
+
+# FUTURE: either get this from the requirements file (if we can be sure it's always available at runtime)
+# or generate the requirements files from this so we only have one source of truth to maintain...
+AZURE_PKG_VERSIONS = {
+ 'StorageManagementClient': {
+ 'package_name': 'storage',
+ 'expected_version': '3.1.0'
+ },
+ 'ComputeManagementClient': {
+ 'package_name': 'compute',
+ 'expected_version': '4.4.0'
+ },
+ 'ContainerInstanceManagementClient': {
+ 'package_name': 'containerinstance',
+ 'expected_version': '0.4.0'
+ },
+ 'NetworkManagementClient': {
+ 'package_name': 'network',
+ 'expected_version': '2.3.0'
+ },
+ 'ResourceManagementClient': {
+ 'package_name': 'resource',
+ 'expected_version': '2.1.0'
+ },
+ 'DnsManagementClient': {
+ 'package_name': 'dns',
+ 'expected_version': '2.1.0'
+ },
+ 'WebSiteManagementClient': {
+ 'package_name': 'web',
+ 'expected_version': '0.41.0'
+ },
+ 'TrafficManagerManagementClient': {
+ 'package_name': 'trafficmanager',
+ 'expected_version': '0.50.0'
+ },
+} if HAS_AZURE else {}
+
+
+AZURE_MIN_RELEASE = '2.0.0'
+
+
+class AzureRMModuleBase(object):
+ def __init__(self, derived_arg_spec, bypass_checks=False, no_log=False,
+ mutually_exclusive=None, required_together=None,
+ required_one_of=None, add_file_common_args=False, supports_check_mode=False,
+ required_if=None, supports_tags=True, facts_module=False, skip_exec=False):
+
+ merged_arg_spec = dict()
+ merged_arg_spec.update(AZURE_COMMON_ARGS)
+ if supports_tags:
+ merged_arg_spec.update(AZURE_TAG_ARGS)
+
+ if derived_arg_spec:
+ merged_arg_spec.update(derived_arg_spec)
+
+ merged_required_if = list(AZURE_COMMON_REQUIRED_IF)
+ if required_if:
+ merged_required_if += required_if
+
+ self.module = AnsibleModule(argument_spec=merged_arg_spec,
+ bypass_checks=bypass_checks,
+ no_log=no_log,
+ mutually_exclusive=mutually_exclusive,
+ required_together=required_together,
+ required_one_of=required_one_of,
+ add_file_common_args=add_file_common_args,
+ supports_check_mode=supports_check_mode,
+ required_if=merged_required_if)
+
+ if not HAS_PACKAGING_VERSION:
+ self.fail(msg=missing_required_lib('packaging'),
+ exception=HAS_PACKAGING_VERSION_EXC)
+
+ if not HAS_MSRESTAZURE:
+ self.fail(msg=missing_required_lib('msrestazure'),
+ exception=HAS_MSRESTAZURE_EXC)
+
+ if not HAS_AZURE:
+ self.fail(msg=missing_required_lib('ansible[azure] (azure >= {0})'.format(AZURE_MIN_RELEASE)),
+ exception=HAS_AZURE_EXC)
+
+ self._network_client = None
+ self._storage_client = None
+ self._resource_client = None
+ self._compute_client = None
+ self._dns_client = None
+ self._web_client = None
+ self._marketplace_client = None
+ self._sql_client = None
+ self._mysql_client = None
+ self._mariadb_client = None
+ self._postgresql_client = None
+ self._containerregistry_client = None
+ self._containerinstance_client = None
+ self._containerservice_client = None
+ self._managedcluster_client = None
+ self._traffic_manager_management_client = None
+ self._monitor_client = None
+ self._resource = None
+ self._log_analytics_client = None
+ self._servicebus_client = None
+ self._automation_client = None
+ self._IoThub_client = None
+ self._lock_client = None
+
+ self.check_mode = self.module.check_mode
+ self.api_profile = self.module.params.get('api_profile')
+ self.facts_module = facts_module
+ # self.debug = self.module.params.get('debug')
+
+ # delegate auth to AzureRMAuth class (shared with all plugin types)
+ self.azure_auth = AzureRMAuth(fail_impl=self.fail, **self.module.params)
+
+ # common parameter validation
+ if self.module.params.get('tags'):
+ self.validate_tags(self.module.params['tags'])
+
+ if not skip_exec:
+ res = self.exec_module(**self.module.params)
+ self.module.exit_json(**res)
+
+ def check_client_version(self, client_type):
+        # Ensure the installed Azure SDK client package meets the minimum expected version.
+ package_version = AZURE_PKG_VERSIONS.get(client_type.__name__, None)
+ if package_version is not None:
+ client_name = package_version.get('package_name')
+ try:
+ client_module = importlib.import_module(client_type.__module__)
+ client_version = client_module.VERSION
+ except (RuntimeError, AttributeError):
+ # can't get at the module version for some reason, just fail silently...
+ return
+ expected_version = package_version.get('expected_version')
+ if Version(client_version) < Version(expected_version):
+ self.fail("Installed azure-mgmt-{0} client version is {1}. The minimum supported version is {2}. Try "
+ "`pip install ansible[azure]`".format(client_name, client_version, expected_version))
+ if Version(client_version) != Version(expected_version):
+ self.module.warn("Installed azure-mgmt-{0} client version is {1}. The expected version is {2}. Try "
+ "`pip install ansible[azure]`".format(client_name, client_version, expected_version))
+
+ def exec_module(self, **kwargs):
+ self.fail("Error: {0} failed to implement exec_module method.".format(self.__class__.__name__))
+
+ def fail(self, msg, **kwargs):
+ '''
+ Shortcut for calling module.fail()
+
+ :param msg: Error message text.
+ :param kwargs: Any key=value pairs
+ :return: None
+ '''
+ self.module.fail_json(msg=msg, **kwargs)
+
+ def deprecate(self, msg, version=None, collection_name=None):
+ self.module.deprecate(msg, version, collection_name=collection_name)
+
+ def log(self, msg, pretty_print=False):
+ if pretty_print:
+ self.module.debug(json.dumps(msg, indent=4, sort_keys=True))
+ else:
+ self.module.debug(msg)
+
+ def validate_tags(self, tags):
+ '''
+ Check if tags dictionary contains string:string pairs.
+
+ :param tags: dictionary of string:string pairs
+ :return: None
+ '''
+ if not self.facts_module:
+ if not isinstance(tags, dict):
+ self.fail("Tags must be a dictionary of string:string values.")
+ for key, value in tags.items():
+ if not isinstance(value, str):
+ self.fail("Tags values must be strings. Found {0}:{1}".format(str(key), str(value)))
+
+ def update_tags(self, tags):
+ '''
+ Call from the module to update metadata tags. Returns tuple
+ with bool indicating if there was a change and dict of new
+ tags to assign to the object.
+
+ :param tags: metadata tags from the object
+ :return: bool, dict
+ '''
+ tags = tags or dict()
+ new_tags = copy.copy(tags) if isinstance(tags, dict) else dict()
+ param_tags = self.module.params.get('tags') if isinstance(self.module.params.get('tags'), dict) else dict()
+ append_tags = self.module.params.get('append_tags') if self.module.params.get('append_tags') is not None else True
+ changed = False
+ # check add or update
+ for key, value in param_tags.items():
+ if not new_tags.get(key) or new_tags[key] != value:
+ changed = True
+ new_tags[key] = value
+ # check remove
+ if not append_tags:
+ for key, value in tags.items():
+ if not param_tags.get(key):
+ new_tags.pop(key)
+ changed = True
+ return changed, new_tags
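+
+    # Example: with existing tags {'env': 'dev'} and module params tags={'env': 'prod'},
+    # append_tags=True, update_tags returns (True, {'env': 'prod'}); with
+    # append_tags=False and params tags={}, it returns (True, {}) because 'env' is removed.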
+
+ def has_tags(self, obj_tags, tag_list):
+ '''
+ Used in fact modules to compare object tags to list of parameter tags. Return true if list of parameter tags
+ exists in object tags.
+
+ :param obj_tags: dictionary of tags from an Azure object.
+ :param tag_list: list of tag keys or tag key:value pairs
+ :return: bool
+ '''
+
+ if not obj_tags and tag_list:
+ return False
+
+ if not tag_list:
+ return True
+
+ matches = 0
+ result = False
+ for tag in tag_list:
+ tag_key = tag
+ tag_value = None
+ if ':' in tag:
+ tag_key, tag_value = tag.split(':')
+ if tag_value and obj_tags.get(tag_key) == tag_value:
+ matches += 1
+ elif not tag_value and obj_tags.get(tag_key):
+ matches += 1
+ if matches == len(tag_list):
+ result = True
+ return result
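+
+    # Example: has_tags({'env': 'prod', 'tier': 'web'}, ['env:prod', 'tier']) => True,
+    # while has_tags({'env': 'prod'}, ['env:dev']) => False.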
+
+ def get_resource_group(self, resource_group):
+ '''
+ Fetch a resource group.
+
+ :param resource_group: name of a resource group
+ :return: resource group object
+ '''
+ try:
+ return self.rm_client.resource_groups.get(resource_group)
+ except CloudError as cloud_error:
+ self.fail("Error retrieving resource group {0} - {1}".format(resource_group, cloud_error.message))
+ except Exception as exc:
+ self.fail("Error retrieving resource group {0} - {1}".format(resource_group, str(exc)))
+
+ def parse_resource_to_dict(self, resource):
+ '''
+        Return a dict for the given resource, containing its name and resource group.
+
+        :param resource: Can be a resource name, a resource id, or a dict containing name and resource group.
+ '''
+ resource_dict = parse_resource_id(resource) if not isinstance(resource, dict) else resource
+ resource_dict['resource_group'] = resource_dict.get('resource_group', self.resource_group)
+ resource_dict['subscription_id'] = resource_dict.get('subscription_id', self.subscription_id)
+ return resource_dict
+
+ def serialize_obj(self, obj, class_name, enum_modules=None):
+ '''
+ Return a JSON representation of an Azure object.
+
+ :param obj: Azure object
+ :param class_name: Name of the object's class
+ :param enum_modules: List of module names to build enum dependencies from.
+ :return: serialized result
+ '''
+ enum_modules = [] if enum_modules is None else enum_modules
+
+ dependencies = dict()
+ if enum_modules:
+ for module_name in enum_modules:
+ mod = importlib.import_module(module_name)
+ for mod_class_name, mod_class_obj in inspect.getmembers(mod, predicate=inspect.isclass):
+ dependencies[mod_class_name] = mod_class_obj
+ self.log("dependencies: ")
+ self.log(str(dependencies))
+ serializer = Serializer(classes=dependencies)
+ return serializer.body(obj, class_name, keep_readonly=True)
+
+ def get_poller_result(self, poller, wait=5):
+ '''
+ Consistent method of waiting on and retrieving results from Azure's long poller
+
+        :param poller: Azure poller object
+        :param wait: polling interval in seconds
+        :return: object resulting from the original request
+ '''
+ try:
+ delay = wait
+ while not poller.done():
+ self.log("Waiting for {0} sec".format(delay))
+ poller.wait(timeout=delay)
+ return poller.result()
+ except Exception as exc:
+ self.log(str(exc))
+ raise
+
+ def check_provisioning_state(self, azure_object, requested_state='present'):
+ '''
+ Check an Azure object's provisioning state. If something did not complete the provisioning
+ process, then we cannot operate on it.
+
+        :param azure_object: An object such as a subnet, storageaccount, etc. Must have provisioning_state
+                             and name attributes.
+        :return: None
+ '''
+
+ if hasattr(azure_object, 'properties') and hasattr(azure_object.properties, 'provisioning_state') and \
+ hasattr(azure_object, 'name'):
+ # resource group object fits this model
+ if isinstance(azure_object.properties.provisioning_state, Enum):
+ if azure_object.properties.provisioning_state.value != AZURE_SUCCESS_STATE and \
+ requested_state != 'absent':
+ self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format(
+ azure_object.name, azure_object.properties.provisioning_state, AZURE_SUCCESS_STATE))
+ return
+ if azure_object.properties.provisioning_state != AZURE_SUCCESS_STATE and \
+ requested_state != 'absent':
+ self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format(
+ azure_object.name, azure_object.properties.provisioning_state, AZURE_SUCCESS_STATE))
+ return
+
+ if hasattr(azure_object, 'provisioning_state') or not hasattr(azure_object, 'name'):
+ if isinstance(azure_object.provisioning_state, Enum):
+ if azure_object.provisioning_state.value != AZURE_SUCCESS_STATE and requested_state != 'absent':
+ self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format(
+ azure_object.name, azure_object.provisioning_state, AZURE_SUCCESS_STATE))
+ return
+ if azure_object.provisioning_state != AZURE_SUCCESS_STATE and requested_state != 'absent':
+ self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format(
+ azure_object.name, azure_object.provisioning_state, AZURE_SUCCESS_STATE))
+
+ def get_blob_client(self, resource_group_name, storage_account_name, storage_blob_type='block'):
+ keys = dict()
+ try:
+ # Get keys from the storage account
+ self.log('Getting keys')
+ account_keys = self.storage_client.storage_accounts.list_keys(resource_group_name, storage_account_name)
+ except Exception as exc:
+ self.fail("Error getting keys for account {0} - {1}".format(storage_account_name, str(exc)))
+
+ try:
+ self.log('Create blob service')
+ if storage_blob_type == 'page':
+ return PageBlobService(endpoint_suffix=self._cloud_environment.suffixes.storage_endpoint,
+ account_name=storage_account_name,
+ account_key=account_keys.keys[0].value)
+ elif storage_blob_type == 'block':
+ return BlockBlobService(endpoint_suffix=self._cloud_environment.suffixes.storage_endpoint,
+ account_name=storage_account_name,
+ account_key=account_keys.keys[0].value)
+ else:
+ raise Exception("Invalid storage blob type defined.")
+ except Exception as exc:
+ self.fail("Error creating blob service client for storage account {0} - {1}".format(storage_account_name,
+ str(exc)))
+
+ def create_default_pip(self, resource_group, location, public_ip_name, allocation_method='Dynamic', sku=None):
+ '''
+ Create a default public IP address <public_ip_name> to associate with a network interface.
+ If a PIP address matching <public_ip_name> exists, return it. Otherwise, create one.
+
+ :param resource_group: name of an existing resource group
+ :param location: a valid azure location
+ :param public_ip_name: base name to assign the public IP address
+ :param allocation_method: one of 'Static' or 'Dynamic'
+        :param sku: SKU of the public IP address (for example, 'Basic' or 'Standard')
+ :return: PIP object
+ '''
+ pip = None
+
+ self.log("Starting create_default_pip {0}".format(public_ip_name))
+ self.log("Check to see if public IP {0} exists".format(public_ip_name))
+ try:
+ pip = self.network_client.public_ip_addresses.get(resource_group, public_ip_name)
+ except CloudError:
+ pass
+
+ if pip:
+ self.log("Public ip {0} found.".format(public_ip_name))
+ self.check_provisioning_state(pip)
+ return pip
+
+ params = self.network_models.PublicIPAddress(
+ location=location,
+ public_ip_allocation_method=allocation_method,
+ sku=sku
+ )
+ self.log('Creating default public IP {0}'.format(public_ip_name))
+ try:
+ poller = self.network_client.public_ip_addresses.create_or_update(resource_group, public_ip_name, params)
+ except Exception as exc:
+ self.fail("Error creating {0} - {1}".format(public_ip_name, str(exc)))
+
+ return self.get_poller_result(poller)
+
+ def create_default_securitygroup(self, resource_group, location, security_group_name, os_type, open_ports):
+ '''
+ Create a default security group <security_group_name> to associate with a network interface. If a security group matching
+ <security_group_name> exists, return it. Otherwise, create one.
+
+ :param resource_group: Resource group name
+ :param location: azure location name
+ :param security_group_name: base name to use for the security group
+        :param os_type: one of 'Windows' or 'Linux'. Determines the default rules added to the security group.
+        :param open_ports: optional list of ports to open instead of the OS-type defaults
+                           (SSH for Linux; RDP and WinRM for Windows).
+ :return: security_group object
+ '''
+ group = None
+
+ self.log("Create security group {0}".format(security_group_name))
+ self.log("Check to see if security group {0} exists".format(security_group_name))
+ try:
+ group = self.network_client.network_security_groups.get(resource_group, security_group_name)
+ except CloudError:
+ pass
+
+ if group:
+ self.log("Security group {0} found.".format(security_group_name))
+ self.check_provisioning_state(group)
+ return group
+
+ parameters = self.network_models.NetworkSecurityGroup()
+ parameters.location = location
+
+ if not open_ports:
+ # Open default ports based on OS type
+ if os_type == 'Linux':
+ # add an inbound SSH rule
+ parameters.security_rules = [
+ self.network_models.SecurityRule(protocol='Tcp',
+ source_address_prefix='*',
+ destination_address_prefix='*',
+ access='Allow',
+ direction='Inbound',
+ description='Allow SSH Access',
+ source_port_range='*',
+ destination_port_range='22',
+ priority=100,
+ name='SSH')
+ ]
+ parameters.location = location
+ else:
+ # for windows add inbound RDP and WinRM rules
+ parameters.security_rules = [
+ self.network_models.SecurityRule(protocol='Tcp',
+ source_address_prefix='*',
+ destination_address_prefix='*',
+ access='Allow',
+ direction='Inbound',
+ description='Allow RDP port 3389',
+ source_port_range='*',
+ destination_port_range='3389',
+ priority=100,
+ name='RDP01'),
+ self.network_models.SecurityRule(protocol='Tcp',
+ source_address_prefix='*',
+ destination_address_prefix='*',
+ access='Allow',
+ direction='Inbound',
+ description='Allow WinRM HTTPS port 5986',
+ source_port_range='*',
+ destination_port_range='5986',
+ priority=101,
+ name='WinRM01'),
+ ]
+ else:
+ # Open custom ports
+ parameters.security_rules = []
+ priority = 100
+ for port in open_ports:
+ priority += 1
+ rule_name = "Rule_{0}".format(priority)
+ parameters.security_rules.append(
+ self.network_models.SecurityRule(protocol='Tcp',
+ source_address_prefix='*',
+ destination_address_prefix='*',
+ access='Allow',
+ direction='Inbound',
+ source_port_range='*',
+ destination_port_range=str(port),
+ priority=priority,
+ name=rule_name)
+ )
+
+ self.log('Creating default security group {0}'.format(security_group_name))
+ try:
+ poller = self.network_client.network_security_groups.create_or_update(resource_group,
+ security_group_name,
+ parameters)
+ except Exception as exc:
+ self.fail("Error creating default security rule {0} - {1}".format(security_group_name, str(exc)))
+
+ return self.get_poller_result(poller)
+
+ @staticmethod
+ def _validation_ignore_callback(session, global_config, local_config, **kwargs):
+ session.verify = False
+
+ def get_api_profile(self, client_type_name, api_profile_name):
+ profile_all_clients = AZURE_API_PROFILES.get(api_profile_name)
+
+ if not profile_all_clients:
+ raise KeyError("unknown Azure API profile: {0}".format(api_profile_name))
+
+ profile_raw = profile_all_clients.get(client_type_name, None)
+
+ if not profile_raw:
+ self.module.warn("Azure API profile {0} does not define an entry for {1}".format(api_profile_name, client_type_name))
+
+ if isinstance(profile_raw, dict):
+ if not profile_raw.get('default_api_version'):
+ raise KeyError("Azure API profile {0} does not define 'default_api_version'".format(api_profile_name))
+ return profile_raw
+
+ # wrap basic strings in a dict that just defines the default
+ return dict(default_api_version=profile_raw)
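+
+    # Example: get_api_profile('StorageManagementClient', 'latest') returns
+    # {'default_api_version': '2017-10-01'} per AZURE_API_PROFILES above.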
+
+ def get_mgmt_svc_client(self, client_type, base_url=None, api_version=None):
+ self.log('Getting management service client {0}'.format(client_type.__name__))
+ self.check_client_version(client_type)
+
+ client_argspec = inspect.getargspec(client_type.__init__)
+
+ if not base_url:
+ # most things are resource_manager, don't make everyone specify
+ base_url = self.azure_auth._cloud_environment.endpoints.resource_manager
+
+ client_kwargs = dict(credentials=self.azure_auth.azure_credentials, subscription_id=self.azure_auth.subscription_id, base_url=base_url)
+
+ api_profile_dict = {}
+
+ if self.api_profile:
+ api_profile_dict = self.get_api_profile(client_type.__name__, self.api_profile)
+
+ # unversioned clients won't accept profile; only send it if necessary
+ # clients without a version specified in the profile will use the default
+ if api_profile_dict and 'profile' in client_argspec.args:
+ client_kwargs['profile'] = api_profile_dict
+
+ # If the client doesn't accept api_version, it's unversioned.
+ # If it does, favor explicitly-specified api_version, fall back to api_profile
+ if 'api_version' in client_argspec.args:
+ profile_default_version = api_profile_dict.get('default_api_version', None)
+ if api_version or profile_default_version:
+ client_kwargs['api_version'] = api_version or profile_default_version
+ if 'profile' in client_kwargs:
+ # remove profile; only pass API version if specified
+ client_kwargs.pop('profile')
+
+ client = client_type(**client_kwargs)
+
+ # FUTURE: remove this once everything exposes models directly (eg, containerinstance)
+ try:
+ getattr(client, "models")
+ except AttributeError:
+ def _ansible_get_models(self, *arg, **kwarg):
+ return self._ansible_models
+
+ setattr(client, '_ansible_models', importlib.import_module(client_type.__module__).models)
+ client.models = types.MethodType(_ansible_get_models, client)
+
+ client.config = self.add_user_agent(client.config)
+
+ if self.azure_auth._cert_validation_mode == 'ignore':
+ client.config.session_configuration_callback = self._validation_ignore_callback
+
+ return client
+
+ def add_user_agent(self, config):
+ # Add user agent for Ansible
+ config.add_user_agent(ANSIBLE_USER_AGENT)
+ # Add user agent when running from Cloud Shell
+ if CLOUDSHELL_USER_AGENT_KEY in os.environ:
+ config.add_user_agent(os.environ[CLOUDSHELL_USER_AGENT_KEY])
+ # Add user agent when running from VSCode extension
+ if VSCODEEXT_USER_AGENT_KEY in os.environ:
+ config.add_user_agent(os.environ[VSCODEEXT_USER_AGENT_KEY])
+ return config
+
+    def generate_sas_token(self, **kwargs):
+        base_url = kwargs.get('base_url', None)
+        expiry = kwargs.get('expiry', time() + 3600)
+        key = kwargs.get('key', None)
+        policy = kwargs.get('policy', None)
+        url = quote_plus(base_url)
+        ttl = int(expiry)
+        sign_key = '{0}\n{1}'.format(url, ttl)
+        signature = b64encode(HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest())
+        result = {
+            'sr': url,
+            'sig': signature,
+            'se': str(ttl),
+        }
+        if policy:
+            result['skn'] = policy
+        return 'SharedAccessSignature ' + urlencode(result)
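+
+    # Example (namespace and key are illustrative, not real values):
+    #   self.generate_sas_token(base_url='myns.servicebus.windows.net',
+    #                           key='<base64-encoded policy key>',
+    #                           policy='RootManageSharedAccessKey')
+    #   => 'SharedAccessSignature sr=...&sig=...&se=...&skn=RootManageSharedAccessKey'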
+
+    def get_data_svc_client(self, **kwargs):
+        url = kwargs.get('base_url', None)
+        config = AzureConfiguration(base_url='https://{0}'.format(url))
+        config.credentials = AzureSASAuthentication(token=self.generate_sas_token(**kwargs))
+        config = self.add_user_agent(config)
+        return ServiceClient(creds=config.credentials, config=config)
+
+ # passthru methods to AzureAuth instance for backcompat
+ @property
+ def credentials(self):
+ return self.azure_auth.credentials
+
+ @property
+ def _cloud_environment(self):
+ return self.azure_auth._cloud_environment
+
+ @property
+ def subscription_id(self):
+ return self.azure_auth.subscription_id
+
+ @property
+ def storage_client(self):
+ self.log('Getting storage client...')
+ if not self._storage_client:
+ self._storage_client = self.get_mgmt_svc_client(StorageManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2018-07-01')
+ return self._storage_client
+
+ @property
+ def storage_models(self):
+ return StorageManagementClient.models("2018-07-01")
+
+ @property
+ def network_client(self):
+ self.log('Getting network client')
+ if not self._network_client:
+ self._network_client = self.get_mgmt_svc_client(NetworkManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2019-06-01')
+ return self._network_client
+
+ @property
+ def network_models(self):
+ self.log("Getting network models...")
+ return NetworkManagementClient.models("2018-08-01")
+
+ @property
+ def rm_client(self):
+ self.log('Getting resource manager client')
+ if not self._resource_client:
+ self._resource_client = self.get_mgmt_svc_client(ResourceManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2017-05-10')
+ return self._resource_client
+
+ @property
+ def rm_models(self):
+ self.log("Getting resource manager models")
+ return ResourceManagementClient.models("2017-05-10")
+
+ @property
+ def compute_client(self):
+ self.log('Getting compute client')
+ if not self._compute_client:
+ self._compute_client = self.get_mgmt_svc_client(ComputeManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2019-07-01')
+ return self._compute_client
+
+ @property
+ def compute_models(self):
+ self.log("Getting compute models")
+ return ComputeManagementClient.models("2019-07-01")
+
+ @property
+ def dns_client(self):
+ self.log('Getting dns client')
+ if not self._dns_client:
+ self._dns_client = self.get_mgmt_svc_client(DnsManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2018-05-01')
+ return self._dns_client
+
+ @property
+ def dns_models(self):
+ self.log("Getting dns models...")
+ return DnsManagementClient.models('2018-05-01')
+
+ @property
+ def web_client(self):
+ self.log('Getting web client')
+ if not self._web_client:
+ self._web_client = self.get_mgmt_svc_client(WebSiteManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2018-02-01')
+ return self._web_client
+
+ @property
+ def containerservice_client(self):
+ self.log('Getting container service client')
+ if not self._containerservice_client:
+ self._containerservice_client = self.get_mgmt_svc_client(ContainerServiceClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2017-07-01')
+ return self._containerservice_client
+
+ @property
+ def managedcluster_models(self):
+ self.log("Getting container service models")
+ return ContainerServiceClient.models('2018-03-31')
+
+ @property
+ def managedcluster_client(self):
+ self.log('Getting container service client')
+ if not self._managedcluster_client:
+ self._managedcluster_client = self.get_mgmt_svc_client(ContainerServiceClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2018-03-31')
+ return self._managedcluster_client
+
+ @property
+ def sql_client(self):
+ self.log('Getting SQL client')
+ if not self._sql_client:
+ self._sql_client = self.get_mgmt_svc_client(SqlManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._sql_client
+
+ @property
+ def postgresql_client(self):
+ self.log('Getting PostgreSQL client')
+ if not self._postgresql_client:
+ self._postgresql_client = self.get_mgmt_svc_client(PostgreSQLManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._postgresql_client
+
+ @property
+ def mysql_client(self):
+ self.log('Getting MySQL client')
+ if not self._mysql_client:
+ self._mysql_client = self.get_mgmt_svc_client(MySQLManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._mysql_client
+
+ @property
+ def mariadb_client(self):
+ self.log('Getting MariaDB client')
+ if not self._mariadb_client:
+ self._mariadb_client = self.get_mgmt_svc_client(MariaDBManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._mariadb_client
+
+ @property
+ def containerregistry_client(self):
+ self.log('Getting container registry mgmt client')
+ if not self._containerregistry_client:
+ self._containerregistry_client = self.get_mgmt_svc_client(ContainerRegistryManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2017-10-01')
+
+ return self._containerregistry_client
+
+ @property
+ def containerinstance_client(self):
+ self.log('Getting container instance mgmt client')
+ if not self._containerinstance_client:
+ self._containerinstance_client = self.get_mgmt_svc_client(ContainerInstanceManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2018-06-01')
+
+ return self._containerinstance_client
+
+ @property
+ def marketplace_client(self):
+ self.log('Getting marketplace agreement client')
+ if not self._marketplace_client:
+ self._marketplace_client = self.get_mgmt_svc_client(MarketplaceOrderingAgreements,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._marketplace_client
+
+ @property
+ def traffic_manager_management_client(self):
+ self.log('Getting traffic manager client')
+ if not self._traffic_manager_management_client:
+ self._traffic_manager_management_client = self.get_mgmt_svc_client(TrafficManagerManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._traffic_manager_management_client
+
+ @property
+ def monitor_client(self):
+ self.log('Getting monitor client')
+ if not self._monitor_client:
+ self._monitor_client = self.get_mgmt_svc_client(MonitorManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._monitor_client
+
+ @property
+ def log_analytics_client(self):
+ self.log('Getting log analytics client')
+ if not self._log_analytics_client:
+ self._log_analytics_client = self.get_mgmt_svc_client(LogAnalyticsManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._log_analytics_client
+
+ @property
+ def log_analytics_models(self):
+ self.log('Getting log analytics models')
+ return LogAnalyticsModels
+
+ @property
+ def servicebus_client(self):
+ self.log('Getting servicebus client')
+ if not self._servicebus_client:
+ self._servicebus_client = self.get_mgmt_svc_client(ServiceBusManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._servicebus_client
+
+ @property
+ def servicebus_models(self):
+ return ServicebusModel
+
+ @property
+ def automation_client(self):
+ self.log('Getting automation client')
+ if not self._automation_client:
+ self._automation_client = self.get_mgmt_svc_client(AutomationClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._automation_client
+
+ @property
+ def automation_models(self):
+ return AutomationModel
+
+ @property
+ def IoThub_client(self):
+ self.log('Getting iothub client')
+ if not self._IoThub_client:
+ self._IoThub_client = self.get_mgmt_svc_client(IotHubClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+ return self._IoThub_client
+
+ @property
+ def IoThub_models(self):
+ return IoTHubModels
+
+ @property
+ def lock_client(self):
+ self.log('Getting lock client')
+ if not self._lock_client:
+ self._lock_client = self.get_mgmt_svc_client(ManagementLockClient,
+ base_url=self._cloud_environment.endpoints.resource_manager,
+ api_version='2016-09-01')
+ return self._lock_client
+
+ @property
+ def lock_models(self):
+ self.log("Getting lock models")
+ return ManagementLockClient.models('2016-09-01')
+
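+# A minimal sketch of a derived module built on AzureRMModuleBase (the module
+# name and its argument are illustrative, not an actual shipped module):
+#
+#   class AzureRMExampleInfo(AzureRMModuleBase):
+#       def __init__(self):
+#           self.name = None
+#           spec = dict(name=dict(type='str', required=True))
+#           super(AzureRMExampleInfo, self).__init__(derived_arg_spec=spec,
+#                                                    supports_check_mode=True,
+#                                                    facts_module=True)
+#
+#       def exec_module(self, **kwargs):
+#           for key in kwargs:
+#               setattr(self, key, kwargs[key])
+#           return dict(changed=False, name=self.name)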
+
+class AzureSASAuthentication(Authentication):
+ """Simple SAS Authentication.
+ An implementation of Authentication in
+ https://github.com/Azure/msrest-for-python/blob/0732bc90bdb290e5f58c675ffdd7dbfa9acefc93/msrest/authentication.py
+
+ :param str token: SAS token
+ """
+ def __init__(self, token):
+ self.token = token
+
+ def signed_session(self):
+ session = super(AzureSASAuthentication, self).signed_session()
+ session.headers['Authorization'] = self.token
+ return session
+
+
+class AzureRMAuthException(Exception):
+ pass
+
+
+class AzureRMAuth(object):
+ def __init__(self, auth_source='auto', profile=None, subscription_id=None, client_id=None, secret=None,
+ tenant=None, ad_user=None, password=None, cloud_environment='AzureCloud', cert_validation_mode='validate',
+ api_profile='latest', adfs_authority_url=None, fail_impl=None, **kwargs):
+
+ if fail_impl:
+ self._fail_impl = fail_impl
+ else:
+ self._fail_impl = self._default_fail_impl
+
+ self._cloud_environment = None
+ self._adfs_authority_url = None
+
+ # authenticate
+ self.credentials = self._get_credentials(
+ dict(auth_source=auth_source, profile=profile, subscription_id=subscription_id, client_id=client_id, secret=secret,
+ tenant=tenant, ad_user=ad_user, password=password, cloud_environment=cloud_environment,
+ cert_validation_mode=cert_validation_mode, api_profile=api_profile, adfs_authority_url=adfs_authority_url))
+
+ if not self.credentials:
+ if HAS_AZURE_CLI_CORE:
+ self.fail("Failed to get credentials. Either pass as parameters, set environment variables, "
+ "define a profile in ~/.azure/credentials, or log in with Azure CLI (`az login`).")
+ else:
+ self.fail("Failed to get credentials. Either pass as parameters, set environment variables, "
+ "define a profile in ~/.azure/credentials, or install Azure CLI and log in (`az login`).")
+
+ # cert validation mode precedence: module-arg, credential profile, env, "validate"
+ self._cert_validation_mode = cert_validation_mode or self.credentials.get('cert_validation_mode') or \
+ os.environ.get('AZURE_CERT_VALIDATION_MODE') or 'validate'
+
+ if self._cert_validation_mode not in ['validate', 'ignore']:
+ self.fail('invalid cert_validation_mode: {0}'.format(self._cert_validation_mode))
+
+ # if cloud_environment specified, look up/build Cloud object
+ raw_cloud_env = self.credentials.get('cloud_environment')
+ if self.credentials.get('credentials') is not None and raw_cloud_env is not None:
+ self._cloud_environment = raw_cloud_env
+ elif not raw_cloud_env:
+ self._cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD # SDK default
+ else:
+ # try to look up "well-known" values via the name attribute on azure_cloud members
+ all_clouds = [x[1] for x in inspect.getmembers(azure_cloud) if isinstance(x[1], azure_cloud.Cloud)]
+ matched_clouds = [x for x in all_clouds if x.name == raw_cloud_env]
+ if len(matched_clouds) == 1:
+ self._cloud_environment = matched_clouds[0]
+ elif len(matched_clouds) > 1:
+ self.fail("Azure SDK failure: more than one cloud matched for cloud_environment name '{0}'".format(raw_cloud_env))
+ else:
+ if not urlparse.urlparse(raw_cloud_env).scheme:
+ self.fail("cloud_environment must be an endpoint discovery URL or one of {0}".format([x.name for x in all_clouds]))
+ try:
+ self._cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(raw_cloud_env)
+ except Exception as e:
+ self.fail("cloud_environment {0} could not be resolved: {1}".format(raw_cloud_env, e.message), exception=traceback.format_exc())
+
+ if self.credentials.get('subscription_id', None) is None and self.credentials.get('credentials') is None:
+ self.fail("Credentials did not include a subscription_id value.")
+ self.log("setting subscription_id")
+ self.subscription_id = self.credentials['subscription_id']
+
+ # get authentication authority
+ # for adfs, user could pass in authority or not.
+ # for others, use default authority from cloud environment
+ if self.credentials.get('adfs_authority_url') is None:
+ self._adfs_authority_url = self._cloud_environment.endpoints.active_directory
+ else:
+ self._adfs_authority_url = self.credentials.get('adfs_authority_url')
+
+ # get resource from cloud environment
+ self._resource = self._cloud_environment.endpoints.active_directory_resource_id
+
+ if self.credentials.get('credentials') is not None:
+ # AzureCLI credentials
+ self.azure_credentials = self.credentials['credentials']
+ elif self.credentials.get('client_id') is not None and \
+ self.credentials.get('secret') is not None and \
+ self.credentials.get('tenant') is not None:
+ self.azure_credentials = ServicePrincipalCredentials(client_id=self.credentials['client_id'],
+ secret=self.credentials['secret'],
+ tenant=self.credentials['tenant'],
+ cloud_environment=self._cloud_environment,
+ verify=self._cert_validation_mode == 'validate')
+
+ elif self.credentials.get('ad_user') is not None and \
+ self.credentials.get('password') is not None and \
+ self.credentials.get('client_id') is not None and \
+ self.credentials.get('tenant') is not None:
+
+ self.azure_credentials = self.acquire_token_with_username_password(
+ self._adfs_authority_url,
+ self._resource,
+ self.credentials['ad_user'],
+ self.credentials['password'],
+ self.credentials['client_id'],
+ self.credentials['tenant'])
+
+ elif self.credentials.get('ad_user') is not None and self.credentials.get('password') is not None:
+ tenant = self.credentials.get('tenant')
+ if not tenant:
+ tenant = 'common' # SDK default
+
+ self.azure_credentials = UserPassCredentials(self.credentials['ad_user'],
+ self.credentials['password'],
+ tenant=tenant,
+ cloud_environment=self._cloud_environment,
+ verify=self._cert_validation_mode == 'validate')
+ else:
+ self.fail("Failed to authenticate with provided credentials. Some attributes were missing. "
+ "Credentials must include client_id, secret and tenant or ad_user and password, or "
+ "ad_user, password, client_id, tenant and adfs_authority_url(optional) for ADFS authentication, or "
+ "be logged in using AzureCLI.")
+
+ def fail(self, msg, exception=None, **kwargs):
+ self._fail_impl(msg)
+
+ def _default_fail_impl(self, msg, exception=None, **kwargs):
+ raise AzureRMAuthException(msg)
+
+ def _get_profile(self, profile="default"):
+ path = expanduser("~/.azure/credentials")
+ try:
+ config = configparser.ConfigParser()
+ config.read(path)
+ except Exception as exc:
+ self.fail("Failed to access {0}. Check that the file exists and you have read "
+ "access. {1}".format(path, str(exc)))
+ credentials = dict()
+ for key in AZURE_CREDENTIAL_ENV_MAPPING:
+ try:
+ credentials[key] = config.get(profile, key, raw=True)
+ except Exception:
+ pass
+
+ if credentials.get('subscription_id'):
+ return credentials
+
+ return None
+
+ def _get_msi_credentials(self, subscription_id_param=None, **kwargs):
+ client_id = kwargs.get('client_id', None)
+ credentials = MSIAuthentication(client_id=client_id)
+ subscription_id = subscription_id_param or os.environ.get(AZURE_CREDENTIAL_ENV_MAPPING['subscription_id'], None)
+ if not subscription_id:
+ try:
+ # use the first subscription of the MSI
+ subscription_client = SubscriptionClient(credentials)
+ subscription = next(subscription_client.subscriptions.list())
+ subscription_id = str(subscription.subscription_id)
+ except Exception as exc:
+ self.fail("Failed to get MSI token: {0}. "
+ "Please check whether your machine enabled MSI or grant access to any subscription.".format(str(exc)))
+ return {
+ 'credentials': credentials,
+ 'subscription_id': subscription_id
+ }
+
+ def _get_azure_cli_credentials(self):
+ credentials, subscription_id = get_azure_cli_credentials()
+ cloud_environment = get_cli_active_cloud()
+
+ cli_credentials = {
+ 'credentials': credentials,
+ 'subscription_id': subscription_id,
+ 'cloud_environment': cloud_environment
+ }
+ return cli_credentials
+
+ def _get_env_credentials(self):
+ env_credentials = dict()
+ for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items():
+ env_credentials[attribute] = os.environ.get(env_variable, None)
+
+ if env_credentials['profile']:
+ credentials = self._get_profile(env_credentials['profile'])
+ return credentials
+
+ if env_credentials.get('subscription_id') is not None:
+ return env_credentials
+
+ return None
+
+ # TODO: use explicit kwargs instead of intermediate dict
+ def _get_credentials(self, params):
+ # Get authentication credentials.
+ self.log('Getting credentials')
+
+ arg_credentials = dict()
+ for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items():
+ arg_credentials[attribute] = params.get(attribute, None)
+
+ auth_source = params.get('auth_source', None)
+ if not auth_source:
+ auth_source = os.environ.get('ANSIBLE_AZURE_AUTH_SOURCE', 'auto')
+
+ if auth_source == 'msi':
+            self.log('Retrieving credentials from MSI')
+ return self._get_msi_credentials(arg_credentials['subscription_id'], client_id=params.get('client_id', None))
+
+ if auth_source == 'cli':
+ if not HAS_AZURE_CLI_CORE:
+ self.fail(msg=missing_required_lib('azure-cli', reason='for `cli` auth_source'),
+ exception=HAS_AZURE_CLI_CORE_EXC)
+ try:
+ self.log('Retrieving credentials from Azure CLI profile')
+ cli_credentials = self._get_azure_cli_credentials()
+ return cli_credentials
+ except CLIError as err:
+ self.fail("Azure CLI profile cannot be loaded - {0}".format(err))
+
+ if auth_source == 'env':
+ self.log('Retrieving credentials from environment')
+ env_credentials = self._get_env_credentials()
+ return env_credentials
+
+ if auth_source == 'credential_file':
+ self.log("Retrieving credentials from credential file")
+ profile = params.get('profile') or 'default'
+ default_credentials = self._get_profile(profile)
+ return default_credentials
+
+ # auto, precedence: module parameters -> environment variables -> default profile in ~/.azure/credentials
+ # try module params
+ if arg_credentials['profile'] is not None:
+ self.log('Retrieving credentials with profile parameter.')
+ credentials = self._get_profile(arg_credentials['profile'])
+ return credentials
+
+ if arg_credentials['subscription_id']:
+ self.log('Received credentials from parameters.')
+ return arg_credentials
+
+ # try environment
+ env_credentials = self._get_env_credentials()
+ if env_credentials:
+ self.log('Received credentials from env.')
+ return env_credentials
+
+        # try default profile from ~/.azure/credentials
+ default_credentials = self._get_profile()
+ if default_credentials:
+ self.log('Retrieved default profile credentials from ~/.azure/credentials.')
+ return default_credentials
+
+ try:
+ if HAS_AZURE_CLI_CORE:
+ self.log('Retrieving credentials from AzureCLI profile')
+ cli_credentials = self._get_azure_cli_credentials()
+ return cli_credentials
+ except CLIError as ce:
+ self.log('Error getting AzureCLI profile credentials - {0}'.format(ce))
+
+ return None
+
+ def acquire_token_with_username_password(self, authority, resource, username, password, client_id, tenant):
+ authority_uri = authority
+
+ if tenant is not None:
+ authority_uri = authority + '/' + tenant
+
+ context = AuthenticationContext(authority_uri)
+ token_response = context.acquire_token_with_username_password(resource, username, password, client_id)
+
+ return AADTokenCredentials(token_response)
+
+ def log(self, msg, pretty_print=False):
+ pass
+ # Use only during module development
+ # if self.debug:
+ # log_file = open('azure_rm.log', 'a')
+ # if pretty_print:
+ # log_file.write(json.dumps(msg, indent=4, sort_keys=True))
+ # else:
+ # log_file.write(msg + u'\n')
diff --git a/test/support/integration/plugins/module_utils/azure_rm_common_rest.py b/test/support/integration/plugins/module_utils/azure_rm_common_rest.py
new file mode 100644
index 00000000..4fd7eaa3
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/azure_rm_common_rest.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrestazure.azure_configuration import AzureConfiguration
+ from msrest.service_client import ServiceClient
+ from msrest.pipeline import ClientRawResponse
+ from msrest.polling import LROPoller
+ from msrestazure.polling.arm_polling import ARMPolling
+ import uuid
+ import json
+except ImportError:
+ # This is handled in azure_rm_common
+ AzureConfiguration = object
+
+ANSIBLE_USER_AGENT = 'Ansible/{0}'.format(ANSIBLE_VERSION)
+
+
+class GenericRestClientConfiguration(AzureConfiguration):
+
+ def __init__(self, credentials, subscription_id, base_url=None):
+
+ if credentials is None:
+ raise ValueError("Parameter 'credentials' must not be None.")
+ if subscription_id is None:
+ raise ValueError("Parameter 'subscription_id' must not be None.")
+ if not base_url:
+ base_url = 'https://management.azure.com'
+
+ super(GenericRestClientConfiguration, self).__init__(base_url)
+
+ self.add_user_agent(ANSIBLE_USER_AGENT)
+
+ self.credentials = credentials
+ self.subscription_id = subscription_id
+
+
+class GenericRestClient(object):
+
+ def __init__(self, credentials, subscription_id, base_url=None):
+ self.config = GenericRestClientConfiguration(credentials, subscription_id, base_url)
+ self._client = ServiceClient(self.config.credentials, self.config)
+ self.models = None
+
+ def query(self, url, method, query_parameters, header_parameters, body, expected_status_codes, polling_timeout, polling_interval):
+ # Construct and send request
+ operation_config = {}
+
+ request = None
+
+ if header_parameters is None:
+ header_parameters = {}
+
+ header_parameters['x-ms-client-request-id'] = str(uuid.uuid1())
+
+ if method == 'GET':
+ request = self._client.get(url, query_parameters)
+ elif method == 'PUT':
+ request = self._client.put(url, query_parameters)
+ elif method == 'POST':
+ request = self._client.post(url, query_parameters)
+ elif method == 'HEAD':
+ request = self._client.head(url, query_parameters)
+ elif method == 'PATCH':
+ request = self._client.patch(url, query_parameters)
+ elif method == 'DELETE':
+ request = self._client.delete(url, query_parameters)
+ elif method == 'MERGE':
+ request = self._client.merge(url, query_parameters)
+
+ response = self._client.send(request, header_parameters, body, **operation_config)
+
+ if response.status_code not in expected_status_codes:
+ exp = CloudError(response)
+ exp.request_id = response.headers.get('x-ms-request-id')
+ raise exp
+ elif response.status_code == 202 and polling_timeout > 0:
+ def get_long_running_output(response):
+ return response
+ poller = LROPoller(self._client,
+ ClientRawResponse(None, response),
+ get_long_running_output,
+ ARMPolling(polling_interval, **operation_config))
+ response = self.get_poller_result(poller, polling_timeout)
+
+ return response
+
+    def get_poller_result(self, poller, timeout):
+        try:
+            poller.wait(timeout=timeout)
+            return poller.result()
+        except Exception:
+            raise
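+
+# A minimal usage sketch (the URL path and api-version are illustrative):
+#
+#   client = GenericRestClient(credentials, subscription_id)
+#   response = client.query(url='/subscriptions/{0}/resourceGroups?api-version=2017-05-10'.format(subscription_id),
+#                           method='GET', query_parameters={}, header_parameters={},
+#                           body=None, expected_status_codes=[200],
+#                           polling_timeout=0, polling_interval=0)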
diff --git a/test/support/integration/plugins/module_utils/cloud.py b/test/support/integration/plugins/module_utils/cloud.py
new file mode 100644
index 00000000..0d29071f
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/cloud.py
@@ -0,0 +1,217 @@
+#
+# (c) 2016 Allen Sanabria, <asanabria@linuxdynasty.org>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+"""
+This module adds shared support for generic cloud modules
+
+In order to use this module, include it as part of a custom
+module as shown below.
+
+from ansible.module_utils.cloud import CloudRetry
+
+The 'cloud' module provides the following common classes:
+
+ * CloudRetry
+ - The base class to be used by other cloud providers, in order to
+ provide a backoff/retry decorator based on status codes.
+
+ - Example using the AWSRetry class which inherits from CloudRetry.
+
+ @AWSRetry.exponential_backoff(retries=10, delay=3)
+ get_ec2_security_group_ids_from_names()
+
+ @AWSRetry.jittered_backoff()
+ get_ec2_security_group_ids_from_names()
+
+"""
+import random
+from functools import wraps
+import syslog
+import time
+
+
+def _exponential_backoff(retries=10, delay=2, backoff=2, max_delay=60):
+ """ Customizable exponential backoff strategy.
+ Args:
+ retries (int): Maximum number of times to retry a request.
+ delay (float): Initial (base) delay.
+ backoff (float): base of the exponent to use for exponential
+ backoff.
+        max_delay (int): Optional. If provided, each delay generated is capped
+ at this amount. Defaults to 60 seconds.
+ Returns:
+ Callable that returns a generator. This generator yields durations in
+ seconds to be used as delays for an exponential backoff strategy.
+ Usage:
+ >>> backoff = _exponential_backoff()
+ >>> backoff
+ <function backoff_backoff at 0x7f0d939facf8>
+ >>> list(backoff())
+ [2, 4, 8, 16, 32, 60, 60, 60, 60, 60]
+ """
+ def backoff_gen():
+ for retry in range(0, retries):
+ sleep = delay * backoff ** retry
+ yield sleep if max_delay is None else min(sleep, max_delay)
+ return backoff_gen
+
+
+def _full_jitter_backoff(retries=10, delay=3, max_delay=60, _random=random):
+ """ Implements the "Full Jitter" backoff strategy described here
+ https://www.awsarchitectureblog.com/2015/03/backoff.html
+ Args:
+ retries (int): Maximum number of times to retry a request.
+ delay (float): Approximate number of seconds to sleep for the first
+ retry.
+ max_delay (int): The maximum number of seconds to sleep for any retry.
+ _random (random.Random or None): Makes this generator testable by
+            allowing developers to explicitly pass in a seeded Random.
+ Returns:
+ Callable that returns a generator. This generator yields durations in
+ seconds to be used as delays for a full jitter backoff strategy.
+ Usage:
+ >>> backoff = _full_jitter_backoff(retries=5)
+ >>> backoff
+ <function backoff_backoff at 0x7f0d939facf8>
+ >>> list(backoff())
+ [3, 6, 5, 23, 38]
+ >>> list(backoff())
+ [2, 1, 6, 6, 31]
+ """
+ def backoff_gen():
+ for retry in range(0, retries):
+ yield _random.randint(0, min(max_delay, delay * 2 ** retry))
+ return backoff_gen
+
+
+class CloudRetry(object):
+ """ CloudRetry can be used by any cloud provider, in order to implement a
+ backoff algorithm/retry effect based on Status Code from Exceptions.
+ """
+ # This is the base class of the exception.
+ # AWS Example botocore.exceptions.ClientError
+ base_class = None
+
+ @staticmethod
+ def status_code_from_exception(error):
+ """ Return the status code from the exception object
+ Args:
+ error (object): The exception itself.
+ """
+ pass
+
+ @staticmethod
+ def found(response_code, catch_extra_error_codes=None):
+ """ Return True if the Response Code to retry on was found.
+ Args:
+ response_code (str): This is the Response Code that is being matched against.
+ """
+ pass
+
+ @classmethod
+ def _backoff(cls, backoff_strategy, catch_extra_error_codes=None):
+ """ Retry calling the Cloud decorated function using the provided
+ backoff strategy.
+ Args:
+ backoff_strategy (callable): Callable that returns a generator. The
+ generator should yield sleep times for each retry of the decorated
+ function.
+ """
+ def deco(f):
+ @wraps(f)
+ def retry_func(*args, **kwargs):
+ for delay in backoff_strategy():
+ try:
+ return f(*args, **kwargs)
+ except Exception as e:
+ if isinstance(e, cls.base_class):
+ response_code = cls.status_code_from_exception(e)
+ if cls.found(response_code, catch_extra_error_codes):
+ msg = "{0}: Retrying in {1} seconds...".format(str(e), delay)
+ syslog.syslog(syslog.LOG_INFO, msg)
+ time.sleep(delay)
+ else:
+                            # Re-raise: the error code is not one we retry on
+                            raise e
+                    else:
+                        # Re-raise: the exception is not an instance of base_class
+                        raise e
+            return f(*args, **kwargs)  # final attempt once all delays are exhausted
+
+ return retry_func # true decorator
+
+ return deco
+
+ @classmethod
+ def exponential_backoff(cls, retries=10, delay=3, backoff=2, max_delay=60, catch_extra_error_codes=None):
+ """
+ Retry calling the Cloud decorated function using an exponential backoff.
+
+ Kwargs:
+ retries (int): Number of times to retry a failed request before giving up
+ default=10
+ delay (int or float): Initial delay between retries in seconds
+ default=3
+ backoff (int or float): backoff multiplier e.g. value of 2 will
+ double the delay each retry
+                default=2
+ max_delay (int or None): maximum amount of time to wait between retries.
+ default=60
+ """
+ return cls._backoff(_exponential_backoff(
+ retries=retries, delay=delay, backoff=backoff, max_delay=max_delay), catch_extra_error_codes)
+
+ @classmethod
+ def jittered_backoff(cls, retries=10, delay=3, max_delay=60, catch_extra_error_codes=None):
+ """
+ Retry calling the Cloud decorated function using a jittered backoff
+ strategy. More on this strategy here:
+
+ https://www.awsarchitectureblog.com/2015/03/backoff.html
+
+ Kwargs:
+ retries (int): Number of times to retry a failed request before giving up
+ default=10
+ delay (int): Initial delay between retries in seconds
+ default=3
+ max_delay (int): maximum amount of time to wait between retries.
+ default=60
+ """
+ return cls._backoff(_full_jitter_backoff(
+ retries=retries, delay=delay, max_delay=max_delay), catch_extra_error_codes)
+
+ @classmethod
+ def backoff(cls, tries=10, delay=3, backoff=1.1, catch_extra_error_codes=None):
+ """
+ Retry calling the Cloud decorated function using an exponential backoff.
+
+ Compatibility for the original implementation of CloudRetry.backoff that
+ did not provide configurable backoff strategies. Developers should use
+ CloudRetry.exponential_backoff instead.
+
+ Kwargs:
+ tries (int): Number of times to try (not retry) before giving up
+ default=10
+ delay (int or float): Initial delay between retries in seconds
+ default=3
+ backoff (int or float): backoff multiplier e.g. value of 2 will
+ double the delay each retry
+ default=1.1
+ """
+ return cls.exponential_backoff(
+ retries=tries - 1, delay=delay, backoff=backoff, max_delay=None, catch_extra_error_codes=catch_extra_error_codes)
diff --git a/test/support/integration/plugins/module_utils/compat/__init__.py b/test/support/integration/plugins/module_utils/compat/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/compat/__init__.py
diff --git a/test/support/integration/plugins/module_utils/compat/ipaddress.py b/test/support/integration/plugins/module_utils/compat/ipaddress.py
new file mode 100644
index 00000000..c46ad72a
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/compat/ipaddress.py
@@ -0,0 +1,2476 @@
+# -*- coding: utf-8 -*-
+
+# This code is part of Ansible, but is an independent component.
+# This particular file, and this file only, is based on
+# Lib/ipaddress.py of cpython
+# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
+#
+# 1. This LICENSE AGREEMENT is between the Python Software Foundation
+# ("PSF"), and the Individual or Organization ("Licensee") accessing and
+# otherwise using this software ("Python") in source or binary form and
+# its associated documentation.
+#
+# 2. Subject to the terms and conditions of this License Agreement, PSF hereby
+# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
+# analyze, test, perform and/or display publicly, prepare derivative works,
+# distribute, and otherwise use Python alone or in any derivative version,
+# provided, however, that PSF's License Agreement and PSF's notice of copyright,
+# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
+# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved"
+# are retained in Python alone or in any derivative version prepared by Licensee.
+#
+# 3. In the event Licensee prepares a derivative work that is based on
+# or incorporates Python or any part thereof, and wants to make
+# the derivative work available to others as provided herein, then
+# Licensee hereby agrees to include in any such work a brief summary of
+# the changes made to Python.
+#
+# 4. PSF is making Python available to Licensee on an "AS IS"
+# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
+# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
+# INFRINGE ANY THIRD PARTY RIGHTS.
+#
+# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
+# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+#
+# 6. This License Agreement will automatically terminate upon a material
+# breach of its terms and conditions.
+#
+# 7. Nothing in this License Agreement shall be deemed to create any
+# relationship of agency, partnership, or joint venture between PSF and
+# Licensee. This License Agreement does not grant permission to use PSF
+# trademarks or trade name in a trademark sense to endorse or promote
+# products or services of Licensee, or any third party.
+#
+# 8. By copying, installing or otherwise using Python, Licensee
+# agrees to be bound by the terms and conditions of this License
+# Agreement.
+
+# Copyright 2007 Google Inc.
+# Licensed to PSF under a Contributor Agreement.
+
+"""A fast, lightweight IPv4/IPv6 manipulation library in Python.
+
+This library is used to create/poke/manipulate IPv4 and IPv6 addresses
+and networks.
+
+"""
+
+from __future__ import unicode_literals
+
+
+import itertools
+import struct
+
+
+# The following makes it easier for us to script updates of the bundled code and is not part of
+# upstream
+_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"}
+
+__version__ = '1.0.22'
+
+# Compatibility functions
+_compat_int_types = (int,)
+try:
+ _compat_int_types = (int, long)
+except NameError:
+ pass
+try:
+ _compat_str = unicode
+except NameError:
+ _compat_str = str
+ assert bytes != str
+if b'\0'[0] == 0: # Python 3 semantics
+ def _compat_bytes_to_byte_vals(byt):
+ return byt
+else:
+ def _compat_bytes_to_byte_vals(byt):
+ return [struct.unpack(b'!B', b)[0] for b in byt]
+try:
+ _compat_int_from_byte_vals = int.from_bytes
+except AttributeError:
+ def _compat_int_from_byte_vals(bytvals, endianess):
+ assert endianess == 'big'
+ res = 0
+ for bv in bytvals:
+ assert isinstance(bv, _compat_int_types)
+ res = (res << 8) + bv
+ return res
+
+
+def _compat_to_bytes(intval, length, endianess):
+ assert isinstance(intval, _compat_int_types)
+ assert endianess == 'big'
+ if length == 4:
+ if intval < 0 or intval >= 2 ** 32:
+ raise struct.error("integer out of range for 'I' format code")
+ return struct.pack(b'!I', intval)
+ elif length == 16:
+ if intval < 0 or intval >= 2 ** 128:
+ raise struct.error("integer out of range for 'QQ' format code")
+ return struct.pack(b'!QQ', intval >> 64, intval & 0xffffffffffffffff)
+ else:
+ raise NotImplementedError()
+
+
+if hasattr(int, 'bit_length'):
+    # Not int.bit_length, since that won't work in 2.7 where long exists
+ def _compat_bit_length(i):
+ return i.bit_length()
+else:
+ def _compat_bit_length(i):
+ for res in itertools.count():
+ if i >> res == 0:
+ return res
+
+
+def _compat_range(start, end, step=1):
+ assert step > 0
+ i = start
+ while i < end:
+ yield i
+ i += step
+
+
+class _TotalOrderingMixin(object):
+ __slots__ = ()
+
+ # Helper that derives the other comparison operations from
+ # __lt__ and __eq__
+ # We avoid functools.total_ordering because it doesn't handle
+ # NotImplemented correctly yet (http://bugs.python.org/issue10042)
+ def __eq__(self, other):
+ raise NotImplementedError
+
+ def __ne__(self, other):
+ equal = self.__eq__(other)
+ if equal is NotImplemented:
+ return NotImplemented
+ return not equal
+
+ def __lt__(self, other):
+ raise NotImplementedError
+
+ def __le__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented or not less:
+ return self.__eq__(other)
+ return less
+
+ def __gt__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented:
+ return NotImplemented
+ equal = self.__eq__(other)
+ if equal is NotImplemented:
+ return NotImplemented
+ return not (less or equal)
+
+ def __ge__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented:
+ return NotImplemented
+ return not less
+
+
+IPV4LENGTH = 32
+IPV6LENGTH = 128
+
+
+class AddressValueError(ValueError):
+ """A Value Error related to the address."""
+
+
+class NetmaskValueError(ValueError):
+ """A Value Error related to the netmask."""
+
+
+def ip_address(address):
+ """Take an IP string/int and return an object of the correct type.
+
+ Args:
+ address: A string or integer, the IP address. Either IPv4 or
+ IPv6 addresses may be supplied; integers less than 2**32 will
+ be considered to be IPv4 by default.
+
+ Returns:
+ An IPv4Address or IPv6Address object.
+
+ Raises:
+ ValueError: if the *address* passed isn't either a v4 or a v6
+ address
+
+ """
+ try:
+ return IPv4Address(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ try:
+ return IPv6Address(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ if isinstance(address, bytes):
+ raise AddressValueError(
+ '%r does not appear to be an IPv4 or IPv6 address. '
+ 'Did you pass in a bytes (str in Python 2) instead of'
+ ' a unicode object?' % address)
+
+ raise ValueError('%r does not appear to be an IPv4 or IPv6 address' %
+ address)
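+
+# Illustrative doctest-style usage (not part of the original module), with
+# the standard documentation addresses:
+#
+#     >>> ip_address('192.0.2.1')
+#     IPv4Address('192.0.2.1')
+#     >>> ip_address('2001:db8::1')
+#     IPv6Address('2001:db8::1')
+#     >>> ip_address(3221225985)  # int form, < 2**32 so treated as IPv4
+#     IPv4Address('192.0.2.1')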
+
+
+def ip_network(address, strict=True):
+ """Take an IP string/int and return an object of the correct type.
+
+ Args:
+ address: A string or integer, the IP network. Either IPv4 or
+ IPv6 networks may be supplied; integers less than 2**32 will
+ be considered to be IPv4 by default.
+
+ Returns:
+ An IPv4Network or IPv6Network object.
+
+ Raises:
+ ValueError: if the string passed isn't either a v4 or a v6
+          address, or if the network has host bits set.
+
+ """
+ try:
+ return IPv4Network(address, strict)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ try:
+ return IPv6Network(address, strict)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ if isinstance(address, bytes):
+ raise AddressValueError(
+ '%r does not appear to be an IPv4 or IPv6 network. '
+ 'Did you pass in a bytes (str in Python 2) instead of'
+ ' a unicode object?' % address)
+
+ raise ValueError('%r does not appear to be an IPv4 or IPv6 network' %
+ address)
+
+
+def ip_interface(address):
+ """Take an IP string/int and return an object of the correct type.
+
+ Args:
+ address: A string or integer, the IP address. Either IPv4 or
+ IPv6 addresses may be supplied; integers less than 2**32 will
+ be considered to be IPv4 by default.
+
+ Returns:
+ An IPv4Interface or IPv6Interface object.
+
+ Raises:
+ ValueError: if the string passed isn't either a v4 or a v6
+ address.
+
+ Notes:
+ The IPv?Interface classes describe an Address on a particular
+ Network, so they're basically a combination of both the Address
+ and Network classes.
+
+ """
+ try:
+ return IPv4Interface(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ try:
+ return IPv6Interface(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' %
+ address)
+
+
+def v4_int_to_packed(address):
+ """Represent an address as 4 packed bytes in network (big-endian) order.
+
+ Args:
+ address: An integer representation of an IPv4 IP address.
+
+ Returns:
+ The integer address packed as 4 bytes in network (big-endian) order.
+
+ Raises:
+ ValueError: If the integer is negative or too large to be an
+ IPv4 IP address.
+
+ """
+ try:
+ return _compat_to_bytes(address, 4, 'big')
+ except (struct.error, OverflowError):
+ raise ValueError("Address negative or too large for IPv4")
+
+
+def v6_int_to_packed(address):
+ """Represent an address as 16 packed bytes in network (big-endian) order.
+
+ Args:
+ address: An integer representation of an IPv6 IP address.
+
+ Returns:
+ The integer address packed as 16 bytes in network (big-endian) order.
+
+ """
+ try:
+ return _compat_to_bytes(address, 16, 'big')
+ except (struct.error, OverflowError):
+ raise ValueError("Address negative or too large for IPv6")
+
+
+def _split_optional_netmask(address):
+ """Helper to split the netmask and raise AddressValueError if needed"""
+ addr = _compat_str(address).split('/')
+ if len(addr) > 2:
+ raise AddressValueError("Only one '/' permitted in %r" % address)
+ return addr
+
+
+def _find_address_range(addresses):
+ """Find a sequence of sorted deduplicated IPv#Address.
+
+ Args:
+ addresses: a list of IPv#Address objects.
+
+ Yields:
+ A tuple containing the first and last IP addresses in the sequence.
+
+ """
+ it = iter(addresses)
+ first = last = next(it) # pylint: disable=stop-iteration-return
+ for ip in it:
+ if ip._ip != last._ip + 1:
+ yield first, last
+ first = ip
+ last = ip
+ yield first, last
+
+
+def _count_righthand_zero_bits(number, bits):
+ """Count the number of zero bits on the right hand side.
+
+ Args:
+ number: an integer.
+ bits: maximum number of bits to count.
+
+ Returns:
+ The number of zero bits on the right hand side of the number.
+
+ """
+ if number == 0:
+ return bits
+ return min(bits, _compat_bit_length(~number & (number - 1)))
+
+
+def summarize_address_range(first, last):
+ """Summarize a network range given the first and last IP addresses.
+
+ Example:
+ >>> list(summarize_address_range(IPv4Address('192.0.2.0'),
+ ... IPv4Address('192.0.2.130')))
+ ... #doctest: +NORMALIZE_WHITESPACE
+ [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'),
+ IPv4Network('192.0.2.130/32')]
+
+ Args:
+ first: the first IPv4Address or IPv6Address in the range.
+ last: the last IPv4Address or IPv6Address in the range.
+
+ Returns:
+ An iterator of the summarized IPv(4|6) network objects.
+
+ Raise:
+ TypeError:
+ If the first and last objects are not IP addresses.
+ If the first and last objects are not the same version.
+ ValueError:
+ If the last object is not greater than the first.
+ If the version of the first address is not 4 or 6.
+
+ """
+ if (not (isinstance(first, _BaseAddress) and
+ isinstance(last, _BaseAddress))):
+ raise TypeError('first and last must be IP addresses, not networks')
+ if first.version != last.version:
+ raise TypeError("%s and %s are not of the same version" % (
+ first, last))
+ if first > last:
+ raise ValueError('last IP address must be greater than first')
+
+ if first.version == 4:
+ ip = IPv4Network
+ elif first.version == 6:
+ ip = IPv6Network
+ else:
+ raise ValueError('unknown IP version')
+
+ ip_bits = first._max_prefixlen
+ first_int = first._ip
+ last_int = last._ip
+ while first_int <= last_int:
+ nbits = min(_count_righthand_zero_bits(first_int, ip_bits),
+ _compat_bit_length(last_int - first_int + 1) - 1)
+ net = ip((first_int, ip_bits - nbits))
+ yield net
+ first_int += 1 << nbits
+ if first_int - 1 == ip._ALL_ONES:
+ break
+
+
+def _collapse_addresses_internal(addresses):
+ """Loops through the addresses, collapsing concurrent netblocks.
+
+ Example:
+
+ ip1 = IPv4Network('192.0.2.0/26')
+ ip2 = IPv4Network('192.0.2.64/26')
+ ip3 = IPv4Network('192.0.2.128/26')
+ ip4 = IPv4Network('192.0.2.192/26')
+
+ _collapse_addresses_internal([ip1, ip2, ip3, ip4]) ->
+ [IPv4Network('192.0.2.0/24')]
+
+ This shouldn't be called directly; it is called via
+ collapse_addresses([]).
+
+ Args:
+ addresses: A list of IPv4Network's or IPv6Network's
+
+ Returns:
+ A list of IPv4Network's or IPv6Network's depending on what we were
+ passed.
+
+ """
+ # First merge
+ to_merge = list(addresses)
+ subnets = {}
+ while to_merge:
+ net = to_merge.pop()
+ supernet = net.supernet()
+ existing = subnets.get(supernet)
+ if existing is None:
+ subnets[supernet] = net
+ elif existing != net:
+ # Merge consecutive subnets
+ del subnets[supernet]
+ to_merge.append(supernet)
+ # Then iterate over resulting networks, skipping subsumed subnets
+ last = None
+ for net in sorted(subnets.values()):
+ if last is not None:
+ # Since they are sorted,
+ # last.network_address <= net.network_address is a given.
+ if last.broadcast_address >= net.broadcast_address:
+ continue
+ yield net
+ last = net
+
+
+def collapse_addresses(addresses):
+ """Collapse a list of IP objects.
+
+ Example:
+ collapse_addresses([IPv4Network('192.0.2.0/25'),
+ IPv4Network('192.0.2.128/25')]) ->
+ [IPv4Network('192.0.2.0/24')]
+
+ Args:
+ addresses: An iterator of IPv4Network or IPv6Network objects.
+
+ Returns:
+ An iterator of the collapsed IPv(4|6)Network objects.
+
+ Raises:
+ TypeError: If passed a list of mixed version objects.
+
+ """
+ addrs = []
+ ips = []
+ nets = []
+
+ # split IP addresses and networks
+ for ip in addresses:
+ if isinstance(ip, _BaseAddress):
+ if ips and ips[-1]._version != ip._version:
+ raise TypeError("%s and %s are not of the same version" % (
+ ip, ips[-1]))
+ ips.append(ip)
+ elif ip._prefixlen == ip._max_prefixlen:
+ if ips and ips[-1]._version != ip._version:
+ raise TypeError("%s and %s are not of the same version" % (
+ ip, ips[-1]))
+ try:
+ ips.append(ip.ip)
+ except AttributeError:
+ ips.append(ip.network_address)
+ else:
+ if nets and nets[-1]._version != ip._version:
+ raise TypeError("%s and %s are not of the same version" % (
+ ip, nets[-1]))
+ nets.append(ip)
+
+ # sort and dedup
+ ips = sorted(set(ips))
+
+ # find consecutive address ranges in the sorted sequence and summarize them
+ if ips:
+ for first, last in _find_address_range(ips):
+ addrs.extend(summarize_address_range(first, last))
+
+ return _collapse_addresses_internal(addrs + nets)
+
+
+def get_mixed_type_key(obj):
+ """Return a key suitable for sorting between networks and addresses.
+
+ Address and Network objects are not sortable by default; they're
+ fundamentally different so the expression
+
+ IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24')
+
+    doesn't make any sense. There are times, however, when you may wish
+ to have ipaddress sort these for you anyway. If you need to do this, you
+ can use this function as the key= argument to sorted().
+
+ Args:
+ obj: either a Network or Address object.
+ Returns:
+ appropriate key.
+
+ """
+ if isinstance(obj, _BaseNetwork):
+ return obj._get_networks_key()
+ elif isinstance(obj, _BaseAddress):
+ return obj._get_address_key()
+ return NotImplemented
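+
+# Illustrative usage (not part of the original module): sorting a mixed list
+# of address and network objects with this key function --
+#
+#     mixed = [IPv4Network('192.0.2.0/24'), IPv4Address('192.0.2.1')]
+#     ordered = sorted(mixed, key=get_mixed_type_key)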
+
+
+class _IPAddressBase(_TotalOrderingMixin):
+
+ """The mother class."""
+
+ __slots__ = ()
+
+ @property
+ def exploded(self):
+ """Return the longhand version of the IP address as a string."""
+ return self._explode_shorthand_ip_string()
+
+ @property
+ def compressed(self):
+ """Return the shorthand version of the IP address as a string."""
+ return _compat_str(self)
+
+ @property
+ def reverse_pointer(self):
+ """The name of the reverse DNS pointer for the IP address, e.g.:
+ >>> ipaddress.ip_address("127.0.0.1").reverse_pointer
+ '1.0.0.127.in-addr.arpa'
+ >>> ipaddress.ip_address("2001:db8::1").reverse_pointer
+ '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa'
+
+ """
+ return self._reverse_pointer()
+
+ @property
+ def version(self):
+ msg = '%200s has no version specified' % (type(self),)
+ raise NotImplementedError(msg)
+
+ def _check_int_address(self, address):
+ if address < 0:
+ msg = "%d (< 0) is not permitted as an IPv%d address"
+ raise AddressValueError(msg % (address, self._version))
+ if address > self._ALL_ONES:
+ msg = "%d (>= 2**%d) is not permitted as an IPv%d address"
+ raise AddressValueError(msg % (address, self._max_prefixlen,
+ self._version))
+
+ def _check_packed_address(self, address, expected_len):
+ address_len = len(address)
+ if address_len != expected_len:
+ msg = (
+ '%r (len %d != %d) is not permitted as an IPv%d address. '
+ 'Did you pass in a bytes (str in Python 2) instead of'
+ ' a unicode object?')
+ raise AddressValueError(msg % (address, address_len,
+ expected_len, self._version))
+
+ @classmethod
+ def _ip_int_from_prefix(cls, prefixlen):
+ """Turn the prefix length into a bitwise netmask
+
+ Args:
+ prefixlen: An integer, the prefix length.
+
+ Returns:
+ An integer.
+
+ """
+ return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen)
+
+ @classmethod
+ def _prefix_from_ip_int(cls, ip_int):
+ """Return prefix length from the bitwise netmask.
+
+ Args:
+ ip_int: An integer, the netmask in expanded bitwise format
+
+ Returns:
+ An integer, the prefix length.
+
+ Raises:
+ ValueError: If the input intermingles zeroes & ones
+ """
+ trailing_zeroes = _count_righthand_zero_bits(ip_int,
+ cls._max_prefixlen)
+ prefixlen = cls._max_prefixlen - trailing_zeroes
+ leading_ones = ip_int >> trailing_zeroes
+ all_ones = (1 << prefixlen) - 1
+ if leading_ones != all_ones:
+ byteslen = cls._max_prefixlen // 8
+ details = _compat_to_bytes(ip_int, byteslen, 'big')
+ msg = 'Netmask pattern %r mixes zeroes & ones'
+ raise ValueError(msg % details)
+ return prefixlen
+
+ @classmethod
+ def _report_invalid_netmask(cls, netmask_str):
+ msg = '%r is not a valid netmask' % netmask_str
+ raise NetmaskValueError(msg)
+
+ @classmethod
+ def _prefix_from_prefix_string(cls, prefixlen_str):
+ """Return prefix length from a numeric string
+
+ Args:
+ prefixlen_str: The string to be converted
+
+ Returns:
+ An integer, the prefix length.
+
+ Raises:
+ NetmaskValueError: If the input is not a valid netmask
+ """
+ # int allows a leading +/- as well as surrounding whitespace,
+ # so we ensure that isn't the case
+ if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str):
+ cls._report_invalid_netmask(prefixlen_str)
+ try:
+ prefixlen = int(prefixlen_str)
+ except ValueError:
+ cls._report_invalid_netmask(prefixlen_str)
+ if not (0 <= prefixlen <= cls._max_prefixlen):
+ cls._report_invalid_netmask(prefixlen_str)
+ return prefixlen
+
+ @classmethod
+ def _prefix_from_ip_string(cls, ip_str):
+ """Turn a netmask/hostmask string into a prefix length
+
+ Args:
+ ip_str: The netmask/hostmask to be converted
+
+ Returns:
+ An integer, the prefix length.
+
+ Raises:
+ NetmaskValueError: If the input is not a valid netmask/hostmask
+ """
+ # Parse the netmask/hostmask like an IP address.
+ try:
+ ip_int = cls._ip_int_from_string(ip_str)
+ except AddressValueError:
+ cls._report_invalid_netmask(ip_str)
+
+ # Try matching a netmask (this would be /1*0*/ as a bitwise regexp).
+ # Note that the two ambiguous cases (all-ones and all-zeroes) are
+ # treated as netmasks.
+ try:
+ return cls._prefix_from_ip_int(ip_int)
+ except ValueError:
+ pass
+
+ # Invert the bits, and try matching a /0+1+/ hostmask instead.
+ ip_int ^= cls._ALL_ONES
+ try:
+ return cls._prefix_from_ip_int(ip_int)
+ except ValueError:
+ cls._report_invalid_netmask(ip_str)
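+
+    # Illustrative examples (not part of the original file): for IPv4, both
+    # spellings of a /24 resolve to the same prefix length --
+    #
+    #     IPv4Network('192.0.2.0/255.255.255.0').prefixlen  -> 24  (netmask)
+    #     IPv4Network('192.0.2.0/0.0.0.255').prefixlen      -> 24  (hostmask)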
+
+ def __reduce__(self):
+ return self.__class__, (_compat_str(self),)
+
+
+class _BaseAddress(_IPAddressBase):
+
+ """A generic IP object.
+
+ This IP class contains the version independent methods which are
+ used by single IP addresses.
+ """
+
+ __slots__ = ()
+
+ def __int__(self):
+ return self._ip
+
+ def __eq__(self, other):
+ try:
+ return (self._ip == other._ip and
+ self._version == other._version)
+ except AttributeError:
+ return NotImplemented
+
+ def __lt__(self, other):
+ if not isinstance(other, _IPAddressBase):
+ return NotImplemented
+ if not isinstance(other, _BaseAddress):
+ raise TypeError('%s and %s are not of the same type' % (
+ self, other))
+ if self._version != other._version:
+ raise TypeError('%s and %s are not of the same version' % (
+ self, other))
+ if self._ip != other._ip:
+ return self._ip < other._ip
+ return False
+
+ # Shorthand for Integer addition and subtraction. This is not
+ # meant to ever support addition/subtraction of addresses.
+ def __add__(self, other):
+ if not isinstance(other, _compat_int_types):
+ return NotImplemented
+ return self.__class__(int(self) + other)
+
+ def __sub__(self, other):
+ if not isinstance(other, _compat_int_types):
+ return NotImplemented
+ return self.__class__(int(self) - other)
+
+ def __repr__(self):
+ return '%s(%r)' % (self.__class__.__name__, _compat_str(self))
+
+ def __str__(self):
+ return _compat_str(self._string_from_ip_int(self._ip))
+
+ def __hash__(self):
+ return hash(hex(int(self._ip)))
+
+ def _get_address_key(self):
+ return (self._version, self)
+
+ def __reduce__(self):
+ return self.__class__, (self._ip,)
+
+
+class _BaseNetwork(_IPAddressBase):
+
+ """A generic IP network object.
+
+ This IP class contains the version independent methods which are
+ used by networks.
+
+ """
+ def __init__(self, address):
+ self._cache = {}
+
+ def __repr__(self):
+ return '%s(%r)' % (self.__class__.__name__, _compat_str(self))
+
+ def __str__(self):
+ return '%s/%d' % (self.network_address, self.prefixlen)
+
+ def hosts(self):
+ """Generate Iterator over usable hosts in a network.
+
+ This is like __iter__ except it doesn't return the network
+ or broadcast addresses.
+
+ """
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in _compat_range(network + 1, broadcast):
+ yield self._address_class(x)
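+
+    # Illustrative example (not part of the original file):
+    #
+    #     list(ip_network('192.0.2.0/30').hosts())
+    #     -> [IPv4Address('192.0.2.1'), IPv4Address('192.0.2.2')]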
+
+ def __iter__(self):
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in _compat_range(network, broadcast + 1):
+ yield self._address_class(x)
+
+ def __getitem__(self, n):
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ if n >= 0:
+ if network + n > broadcast:
+ raise IndexError('address out of range')
+ return self._address_class(network + n)
+ else:
+ n += 1
+ if broadcast + n < network:
+ raise IndexError('address out of range')
+ return self._address_class(broadcast + n)
+
+ def __lt__(self, other):
+ if not isinstance(other, _IPAddressBase):
+ return NotImplemented
+ if not isinstance(other, _BaseNetwork):
+ raise TypeError('%s and %s are not of the same type' % (
+ self, other))
+ if self._version != other._version:
+ raise TypeError('%s and %s are not of the same version' % (
+ self, other))
+ if self.network_address != other.network_address:
+ return self.network_address < other.network_address
+ if self.netmask != other.netmask:
+ return self.netmask < other.netmask
+ return False
+
+ def __eq__(self, other):
+ try:
+ return (self._version == other._version and
+ self.network_address == other.network_address and
+ int(self.netmask) == int(other.netmask))
+ except AttributeError:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(int(self.network_address) ^ int(self.netmask))
+
+ def __contains__(self, other):
+ # always false if one is v4 and the other is v6.
+ if self._version != other._version:
+ return False
+ # dealing with another network.
+ if isinstance(other, _BaseNetwork):
+ return False
+ # dealing with another address
+ else:
+ # address
+ return (int(self.network_address) <= int(other._ip) <=
+ int(self.broadcast_address))
+
+ def overlaps(self, other):
+ """Tell if self is partly contained in other."""
+ return self.network_address in other or (
+ self.broadcast_address in other or (
+ other.network_address in self or (
+ other.broadcast_address in self)))
+
+ @property
+ def broadcast_address(self):
+ x = self._cache.get('broadcast_address')
+ if x is None:
+ x = self._address_class(int(self.network_address) |
+ int(self.hostmask))
+ self._cache['broadcast_address'] = x
+ return x
+
+ @property
+ def hostmask(self):
+ x = self._cache.get('hostmask')
+ if x is None:
+ x = self._address_class(int(self.netmask) ^ self._ALL_ONES)
+ self._cache['hostmask'] = x
+ return x
+
+ @property
+ def with_prefixlen(self):
+ return '%s/%d' % (self.network_address, self._prefixlen)
+
+ @property
+ def with_netmask(self):
+ return '%s/%s' % (self.network_address, self.netmask)
+
+ @property
+ def with_hostmask(self):
+ return '%s/%s' % (self.network_address, self.hostmask)
+
+ @property
+ def num_addresses(self):
+ """Number of hosts in the current subnet."""
+ return int(self.broadcast_address) - int(self.network_address) + 1
+
+ @property
+ def _address_class(self):
+ # Returning bare address objects (rather than interfaces) allows for
+ # more consistent behaviour across the network address, broadcast
+ # address and individual host addresses.
+ msg = '%200s has no associated address class' % (type(self),)
+ raise NotImplementedError(msg)
+
+ @property
+ def prefixlen(self):
+ return self._prefixlen
+
+ def address_exclude(self, other):
+ """Remove an address from a larger block.
+
+ For example:
+
+ addr1 = ip_network('192.0.2.0/28')
+ addr2 = ip_network('192.0.2.1/32')
+ list(addr1.address_exclude(addr2)) =
+ [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'),
+ IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')]
+
+ or IPv6:
+
+ addr1 = ip_network('2001:db8::1/32')
+ addr2 = ip_network('2001:db8::1/128')
+ list(addr1.address_exclude(addr2)) =
+ [ip_network('2001:db8::1/128'),
+ ip_network('2001:db8::2/127'),
+ ip_network('2001:db8::4/126'),
+ ip_network('2001:db8::8/125'),
+ ...
+ ip_network('2001:db8:8000::/33')]
+
+ Args:
+ other: An IPv4Network or IPv6Network object of the same type.
+
+ Returns:
+ An iterator of the IPv(4|6)Network objects which is self
+ minus other.
+
+ Raises:
+ TypeError: If self and other are of differing address
+ versions, or if other is not a network object.
+ ValueError: If other is not completely contained by self.
+
+ """
+ if not self._version == other._version:
+ raise TypeError("%s and %s are not of the same version" % (
+ self, other))
+
+ if not isinstance(other, _BaseNetwork):
+ raise TypeError("%s is not a network object" % other)
+
+ if not other.subnet_of(self):
+ raise ValueError('%s not contained in %s' % (other, self))
+ if other == self:
+ return
+
+ # Make sure we're comparing the network of other.
+ other = other.__class__('%s/%s' % (other.network_address,
+ other.prefixlen))
+
+ s1, s2 = self.subnets()
+ while s1 != other and s2 != other:
+ if other.subnet_of(s1):
+ yield s2
+ s1, s2 = s1.subnets()
+ elif other.subnet_of(s2):
+ yield s1
+ s1, s2 = s2.subnets()
+ else:
+ # If we got here, there's a bug somewhere.
+ raise AssertionError('Error performing exclusion: '
+ 's1: %s s2: %s other: %s' %
+ (s1, s2, other))
+ if s1 == other:
+ yield s2
+ elif s2 == other:
+ yield s1
+ else:
+ # If we got here, there's a bug somewhere.
+ raise AssertionError('Error performing exclusion: '
+ 's1: %s s2: %s other: %s' %
+ (s1, s2, other))
+
+ def compare_networks(self, other):
+ """Compare two IP objects.
+
+ This is only concerned about the comparison of the integer
+ representation of the network addresses. This means that the
+ host bits aren't considered at all in this method. If you want
+ to compare host bits, you can easily enough do a
+ 'HostA._ip < HostB._ip'
+
+ Args:
+ other: An IP object.
+
+ Returns:
+ If the IP versions of self and other are the same, returns:
+
+ -1 if self < other:
+ eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25')
+ IPv6Network('2001:db8::1000/124') <
+ IPv6Network('2001:db8::2000/124')
+ 0 if self == other
+ eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24')
+ IPv6Network('2001:db8::1000/124') ==
+ IPv6Network('2001:db8::1000/124')
+ 1 if self > other
+ eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25')
+ IPv6Network('2001:db8::2000/124') >
+ IPv6Network('2001:db8::1000/124')
+
+ Raises:
+ TypeError if the IP versions are different.
+
+ """
+ # does this need to raise a ValueError?
+ if self._version != other._version:
+ raise TypeError('%s and %s are not of the same type' % (
+ self, other))
+ # self._version == other._version below here:
+ if self.network_address < other.network_address:
+ return -1
+ if self.network_address > other.network_address:
+ return 1
+ # self.network_address == other.network_address below here:
+ if self.netmask < other.netmask:
+ return -1
+ if self.netmask > other.netmask:
+ return 1
+ return 0
+
+ def _get_networks_key(self):
+ """Network-only key function.
+
+ Returns an object that identifies this address' network and
+ netmask. This function is a suitable "key" argument for sorted()
+ and list.sort().
+
+ """
+ return (self._version, self.network_address, self.netmask)
+
+ def subnets(self, prefixlen_diff=1, new_prefix=None):
+ """The subnets which join to make the current subnet.
+
+ In the case that self contains only one IP
+ (self._prefixlen == 32 for IPv4 or self._prefixlen == 128
+ for IPv6), yield an iterator with just ourself.
+
+ Args:
+ prefixlen_diff: An integer, the amount the prefix length
+ should be increased by. This should not be set if
+ new_prefix is also set.
+ new_prefix: The desired new prefix length. This must be a
+ larger number (smaller prefix) than the existing prefix.
+ This should not be set if prefixlen_diff is also set.
+
+ Returns:
+ An iterator of IPv(4|6) objects.
+
+ Raises:
+ ValueError: The prefixlen_diff is too small or too large.
+ OR
+ prefixlen_diff and new_prefix are both set or new_prefix
+ is a smaller number than the current prefix (smaller
+ number means a larger network)
+
+ """
+ if self._prefixlen == self._max_prefixlen:
+ yield self
+ return
+
+ if new_prefix is not None:
+ if new_prefix < self._prefixlen:
+ raise ValueError('new prefix must be longer')
+ if prefixlen_diff != 1:
+ raise ValueError('cannot set prefixlen_diff and new_prefix')
+ prefixlen_diff = new_prefix - self._prefixlen
+
+ if prefixlen_diff < 0:
+ raise ValueError('prefix length diff must be > 0')
+ new_prefixlen = self._prefixlen + prefixlen_diff
+
+ if new_prefixlen > self._max_prefixlen:
+ raise ValueError(
+ 'prefix length diff %d is invalid for netblock %s' % (
+ new_prefixlen, self))
+
+ start = int(self.network_address)
+ end = int(self.broadcast_address) + 1
+ step = (int(self.hostmask) + 1) >> prefixlen_diff
+ for new_addr in _compat_range(start, end, step):
+ current = self.__class__((new_addr, new_prefixlen))
+ yield current
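+
+    # Illustrative examples (not part of the original file):
+    #
+    #     list(ip_network('192.0.2.0/24').subnets())
+    #     -> [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/25')]
+    #
+    #     list(ip_network('192.0.2.0/24').subnets(new_prefix=26))  # four /26s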
+
+ def supernet(self, prefixlen_diff=1, new_prefix=None):
+ """The supernet containing the current network.
+
+ Args:
+ prefixlen_diff: An integer, the amount the prefix length of
+ the network should be decreased by. For example, given a
+ /24 network and a prefixlen_diff of 3, a supernet with a
+ /21 netmask is returned.
+
+ Returns:
+ An IPv4 network object.
+
+ Raises:
+ ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have
+ a negative prefix length.
+ OR
+ If prefixlen_diff and new_prefix are both set or new_prefix is a
+ larger number than the current prefix (larger number means a
+ smaller network)
+
+ """
+ if self._prefixlen == 0:
+ return self
+
+ if new_prefix is not None:
+ if new_prefix > self._prefixlen:
+ raise ValueError('new prefix must be shorter')
+ if prefixlen_diff != 1:
+ raise ValueError('cannot set prefixlen_diff and new_prefix')
+ prefixlen_diff = self._prefixlen - new_prefix
+
+ new_prefixlen = self.prefixlen - prefixlen_diff
+ if new_prefixlen < 0:
+ raise ValueError(
+ 'current prefixlen is %d, cannot have a prefixlen_diff of %d' %
+ (self.prefixlen, prefixlen_diff))
+ return self.__class__((
+ int(self.network_address) & (int(self.netmask) << prefixlen_diff),
+ new_prefixlen))
+
+ @property
+ def is_multicast(self):
+ """Test if the address is reserved for multicast use.
+
+ Returns:
+ A boolean, True if the address is a multicast address.
+ See RFC 2373 2.7 for details.
+
+ """
+ return (self.network_address.is_multicast and
+ self.broadcast_address.is_multicast)
+
+ @staticmethod
+ def _is_subnet_of(a, b):
+ try:
+ # Always false if one is v4 and the other is v6.
+ if a._version != b._version:
+ raise TypeError("%s and %s are not of the same version" % (a, b))
+ return (b.network_address <= a.network_address and
+ b.broadcast_address >= a.broadcast_address)
+ except AttributeError:
+ raise TypeError("Unable to test subnet containment "
+ "between %s and %s" % (a, b))
+
+ def subnet_of(self, other):
+ """Return True if this network is a subnet of other."""
+ return self._is_subnet_of(self, other)
+
+ def supernet_of(self, other):
+ """Return True if this network is a supernet of other."""
+ return self._is_subnet_of(other, self)
+
+ @property
+ def is_reserved(self):
+ """Test if the address is otherwise IETF reserved.
+
+ Returns:
+ A boolean, True if the address is within one of the
+ reserved IPv6 Network ranges.
+
+ """
+ return (self.network_address.is_reserved and
+ self.broadcast_address.is_reserved)
+
+ @property
+ def is_link_local(self):
+ """Test if the address is reserved for link-local.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 4291.
+
+ """
+ return (self.network_address.is_link_local and
+ self.broadcast_address.is_link_local)
+
+ @property
+ def is_private(self):
+ """Test if this address is allocated for private networks.
+
+ Returns:
+ A boolean, True if the address is reserved per
+ iana-ipv4-special-registry or iana-ipv6-special-registry.
+
+ """
+ return (self.network_address.is_private and
+ self.broadcast_address.is_private)
+
+ @property
+ def is_global(self):
+ """Test if this address is allocated for public networks.
+
+ Returns:
+ A boolean, True if the address is not reserved per
+ iana-ipv4-special-registry or iana-ipv6-special-registry.
+
+ """
+ return not self.is_private
+
+ @property
+ def is_unspecified(self):
+ """Test if the address is unspecified.
+
+ Returns:
+ A boolean, True if this is the unspecified address as defined in
+ RFC 2373 2.5.2.
+
+ """
+ return (self.network_address.is_unspecified and
+ self.broadcast_address.is_unspecified)
+
+ @property
+ def is_loopback(self):
+ """Test if the address is a loopback address.
+
+ Returns:
+ A boolean, True if the address is a loopback address as defined in
+ RFC 2373 2.5.3.
+
+ """
+ return (self.network_address.is_loopback and
+ self.broadcast_address.is_loopback)
+
+
+class _BaseV4(object):
+
+ """Base IPv4 object.
+
+ The following methods are used by IPv4 objects in both single IP
+ addresses and networks.
+
+ """
+
+ __slots__ = ()
+ _version = 4
+ # Equivalent to 255.255.255.255 or 32 bits of 1's.
+ _ALL_ONES = (2 ** IPV4LENGTH) - 1
+ _DECIMAL_DIGITS = frozenset('0123456789')
+
+ # the valid octets for host and netmasks. only useful for IPv4.
+ _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0])
+
+ _max_prefixlen = IPV4LENGTH
+ # There are only a handful of valid v4 netmasks, so we cache them all
+ # when constructed (see _make_netmask()).
+ _netmask_cache = {}
+
+ def _explode_shorthand_ip_string(self):
+ return _compat_str(self)
+
+ @classmethod
+ def _make_netmask(cls, arg):
+ """Make a (netmask, prefix_len) tuple from the given argument.
+
+ Argument can be:
+ - an integer (the prefix length)
+ - a string representing the prefix length (e.g. "24")
+ - a string representing the prefix netmask (e.g. "255.255.255.0")
+ """
+ if arg not in cls._netmask_cache:
+ if isinstance(arg, _compat_int_types):
+ prefixlen = arg
+ else:
+ try:
+ # Check for a netmask in prefix length form
+ prefixlen = cls._prefix_from_prefix_string(arg)
+ except NetmaskValueError:
+ # Check for a netmask or hostmask in dotted-quad form.
+ # This may raise NetmaskValueError.
+ prefixlen = cls._prefix_from_ip_string(arg)
+ netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen))
+ cls._netmask_cache[arg] = netmask, prefixlen
+ return cls._netmask_cache[arg]
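+
+    # Illustrative examples (not part of the original file): all three
+    # accepted argument forms resolve to the same cached tuple --
+    #
+    #     IPv4Network._make_netmask(24)
+    #     IPv4Network._make_netmask('24')
+    #     IPv4Network._make_netmask('255.255.255.0')
+    #     each -> (IPv4Address('255.255.255.0'), 24)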
+
+ @classmethod
+ def _ip_int_from_string(cls, ip_str):
+ """Turn the given IP string into an integer for comparison.
+
+ Args:
+ ip_str: A string, the IP ip_str.
+
+ Returns:
+ The IP ip_str as an integer.
+
+ Raises:
+ AddressValueError: if ip_str isn't a valid IPv4 Address.
+
+ """
+ if not ip_str:
+ raise AddressValueError('Address cannot be empty')
+
+ octets = ip_str.split('.')
+ if len(octets) != 4:
+ raise AddressValueError("Expected 4 octets in %r" % ip_str)
+
+ try:
+ return _compat_int_from_byte_vals(
+ map(cls._parse_octet, octets), 'big')
+ except ValueError as exc:
+ raise AddressValueError("%s in %r" % (exc, ip_str))
+
+ @classmethod
+ def _parse_octet(cls, octet_str):
+ """Convert a decimal octet into an integer.
+
+ Args:
+ octet_str: A string, the number to parse.
+
+ Returns:
+ The octet as an integer.
+
+ Raises:
+ ValueError: if the octet isn't strictly a decimal from [0..255].
+
+ """
+ if not octet_str:
+ raise ValueError("Empty octet not permitted")
+ # Whitelist the characters, since int() allows a lot of bizarre stuff.
+ if not cls._DECIMAL_DIGITS.issuperset(octet_str):
+ msg = "Only decimal digits permitted in %r"
+ raise ValueError(msg % octet_str)
+ # We do the length check second, since the invalid character error
+ # is likely to be more informative for the user
+ if len(octet_str) > 3:
+ msg = "At most 3 characters permitted in %r"
+ raise ValueError(msg % octet_str)
+ # Convert to integer (we know digits are legal)
+ octet_int = int(octet_str, 10)
+ # Any octets that look like they *might* be written in octal,
+ # and which don't look exactly the same in both octal and
+ # decimal are rejected as ambiguous
+ if octet_int > 7 and octet_str[0] == '0':
+ msg = "Ambiguous (octal/decimal) value in %r not permitted"
+ raise ValueError(msg % octet_str)
+ if octet_int > 255:
+ raise ValueError("Octet %d (> 255) not permitted" % octet_int)
+ return octet_int
+
+ @classmethod
+ def _string_from_ip_int(cls, ip_int):
+ """Turns a 32-bit integer into dotted decimal notation.
+
+ Args:
+ ip_int: An integer, the IP address.
+
+ Returns:
+ The IP address as a string in dotted decimal notation.
+
+ """
+ return '.'.join(_compat_str(struct.unpack(b'!B', b)[0]
+ if isinstance(b, bytes)
+ else b)
+ for b in _compat_to_bytes(ip_int, 4, 'big'))
+
+ def _is_hostmask(self, ip_str):
+ """Test if the IP string is a hostmask (rather than a netmask).
+
+ Args:
+ ip_str: A string, the potential hostmask.
+
+ Returns:
+ A boolean, True if the IP string is a hostmask.
+
+ """
+ bits = ip_str.split('.')
+ try:
+ parts = [x for x in map(int, bits) if x in self._valid_mask_octets]
+ except ValueError:
+ return False
+ if len(parts) != len(bits):
+ return False
+ if parts[0] < parts[-1]:
+ return True
+ return False
+
+ def _reverse_pointer(self):
+ """Return the reverse DNS pointer name for the IPv4 address.
+
+ This implements the method described in RFC1035 3.5.
+
+ """
+ reverse_octets = _compat_str(self).split('.')[::-1]
+ return '.'.join(reverse_octets) + '.in-addr.arpa'
+
+ @property
+ def max_prefixlen(self):
+ return self._max_prefixlen
+
+ @property
+ def version(self):
+ return self._version
+
+
+class IPv4Address(_BaseV4, _BaseAddress):
+
+ """Represent and manipulate single IPv4 Addresses."""
+
+ __slots__ = ('_ip', '__weakref__')
+
+ def __init__(self, address):
+
+ """
+ Args:
+ address: A string or integer representing the IP
+
+ Additionally, an integer can be passed, so
+ IPv4Address('192.0.2.1') == IPv4Address(3221225985).
+ or, more generally
+ IPv4Address(int(IPv4Address('192.0.2.1'))) ==
+ IPv4Address('192.0.2.1')
+
+ Raises:
+ AddressValueError: If ipaddress isn't a valid IPv4 address.
+
+ """
+ # Efficient constructor from integer.
+ if isinstance(address, _compat_int_types):
+ self._check_int_address(address)
+ self._ip = address
+ return
+
+ # Constructing from a packed address
+ if isinstance(address, bytes):
+ self._check_packed_address(address, 4)
+ bvs = _compat_bytes_to_byte_vals(address)
+ self._ip = _compat_int_from_byte_vals(bvs, 'big')
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP string.
+ addr_str = _compat_str(address)
+ if '/' in addr_str:
+ raise AddressValueError("Unexpected '/' in %r" % address)
+ self._ip = self._ip_int_from_string(addr_str)
+
+ @property
+ def packed(self):
+ """The binary representation of this address."""
+ return v4_int_to_packed(self._ip)
+
+ @property
+ def is_reserved(self):
+ """Test if the address is otherwise IETF reserved.
+
+ Returns:
+ A boolean, True if the address is within the
+ reserved IPv4 Network range.
+
+ """
+ return self in self._constants._reserved_network
+
+ @property
+ def is_private(self):
+ """Test if this address is allocated for private networks.
+
+ Returns:
+ A boolean, True if the address is reserved per
+ iana-ipv4-special-registry.
+
+ """
+ return any(self in net for net in self._constants._private_networks)
+
+ @property
+ def is_global(self):
+ return (
+ self not in self._constants._public_network and
+ not self.is_private)
+
+ @property
+ def is_multicast(self):
+ """Test if the address is reserved for multicast use.
+
+ Returns:
+ A boolean, True if the address is multicast.
+ See RFC 3171 for details.
+
+ """
+ return self in self._constants._multicast_network
+
+ @property
+ def is_unspecified(self):
+ """Test if the address is unspecified.
+
+ Returns:
+ A boolean, True if this is the unspecified address as defined in
+ RFC 5735 3.
+
+ """
+ return self == self._constants._unspecified_address
+
+ @property
+ def is_loopback(self):
+ """Test if the address is a loopback address.
+
+ Returns:
+ A boolean, True if the address is a loopback per RFC 3330.
+
+ """
+ return self in self._constants._loopback_network
+
+ @property
+ def is_link_local(self):
+ """Test if the address is reserved for link-local.
+
+ Returns:
+ A boolean, True if the address is link-local per RFC 3927.
+
+ """
+ return self in self._constants._linklocal_network
+
+
+class IPv4Interface(IPv4Address):
+
+ def __init__(self, address):
+ if isinstance(address, (bytes, _compat_int_types)):
+ IPv4Address.__init__(self, address)
+ self.network = IPv4Network(self._ip)
+ self._prefixlen = self._max_prefixlen
+ return
+
+ if isinstance(address, tuple):
+ IPv4Address.__init__(self, address[0])
+ if len(address) > 1:
+ self._prefixlen = int(address[1])
+ else:
+ self._prefixlen = self._max_prefixlen
+
+ self.network = IPv4Network(address, strict=False)
+ self.netmask = self.network.netmask
+ self.hostmask = self.network.hostmask
+ return
+
+ addr = _split_optional_netmask(address)
+ IPv4Address.__init__(self, addr[0])
+
+ self.network = IPv4Network(address, strict=False)
+ self._prefixlen = self.network._prefixlen
+
+ self.netmask = self.network.netmask
+ self.hostmask = self.network.hostmask
+
+ def __str__(self):
+ return '%s/%d' % (self._string_from_ip_int(self._ip),
+ self.network.prefixlen)
+
+ def __eq__(self, other):
+ address_equal = IPv4Address.__eq__(self, other)
+ if not address_equal or address_equal is NotImplemented:
+ return address_equal
+ try:
+ return self.network == other.network
+ except AttributeError:
+ # An interface with an associated network is NOT the
+ # same as an unassociated address. That's why the hash
+ # takes the extra info into account.
+ return False
+
+ def __lt__(self, other):
+ address_less = IPv4Address.__lt__(self, other)
+ if address_less is NotImplemented:
+ return NotImplemented
+ try:
+ return (self.network < other.network or
+ self.network == other.network and address_less)
+ except AttributeError:
+ # We *do* allow addresses and interfaces to be sorted. The
+ # unassociated address is considered less than all interfaces.
+ return False
+
+ def __hash__(self):
+ return self._ip ^ self._prefixlen ^ int(self.network.network_address)
+
+ __reduce__ = _IPAddressBase.__reduce__
+
+ @property
+ def ip(self):
+ return IPv4Address(self._ip)
+
+ @property
+ def with_prefixlen(self):
+ return '%s/%s' % (self._string_from_ip_int(self._ip),
+ self._prefixlen)
+
+ @property
+ def with_netmask(self):
+ return '%s/%s' % (self._string_from_ip_int(self._ip),
+ self.netmask)
+
+ @property
+ def with_hostmask(self):
+ return '%s/%s' % (self._string_from_ip_int(self._ip),
+ self.hostmask)
+
+
+class IPv4Network(_BaseV4, _BaseNetwork):
+
+ """This class represents and manipulates 32-bit IPv4 network + addresses..
+
+ Attributes: [examples for IPv4Network('192.0.2.0/27')]
+ .network_address: IPv4Address('192.0.2.0')
+ .hostmask: IPv4Address('0.0.0.31')
+        .broadcast_address: IPv4Address('192.0.2.31')
+ .netmask: IPv4Address('255.255.255.224')
+ .prefixlen: 27
+
+ """
+ # Class to use when creating address objects
+ _address_class = IPv4Address
+
+ def __init__(self, address, strict=True):
+
+ """Instantiate a new IPv4 network object.
+
+ Args:
+ address: A string or integer representing the IP [& network].
+ '192.0.2.0/24'
+ '192.0.2.0/255.255.255.0'
+ '192.0.0.2/0.0.0.255'
+ are all functionally the same in IPv4. Similarly,
+ '192.0.2.1'
+ '192.0.2.1/255.255.255.255'
+ '192.0.2.1/32'
+ are also functionally equivalent. That is to say, failing to
+ provide a subnetmask will create an object with a mask of /32.
+
+ If the mask (portion after the / in the argument) is given in
+ dotted quad form, it is treated as a netmask if it starts with a
+ non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it
+ starts with a zero field (e.g. 0.255.255.255 == /8), with the
+ single exception of an all-zero mask which is treated as a
+ netmask == /0. If no mask is given, a default of /32 is used.
+
+ Additionally, an integer can be passed, so
+ IPv4Network('192.0.2.1') == IPv4Network(3221225985)
+ or, more generally
+ IPv4Interface(int(IPv4Interface('192.0.2.1'))) ==
+ IPv4Interface('192.0.2.1')
+
+ Raises:
+ AddressValueError: If ipaddress isn't a valid IPv4 address.
+ NetmaskValueError: If the netmask isn't valid for
+ an IPv4 address.
+ ValueError: If strict is True and a network address is not
+ supplied.
+
+ """
+ _BaseNetwork.__init__(self, address)
+
+ # Constructing from a packed address or integer
+ if isinstance(address, (_compat_int_types, bytes)):
+ self.network_address = IPv4Address(address)
+ self.netmask, self._prefixlen = self._make_netmask(
+ self._max_prefixlen)
+ # fixme: address/network test here.
+ return
+
+ if isinstance(address, tuple):
+ if len(address) > 1:
+ arg = address[1]
+ else:
+ # We weren't given an address[1]
+ arg = self._max_prefixlen
+ self.network_address = IPv4Address(address[0])
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+ packed = int(self.network_address)
+ if packed & int(self.netmask) != packed:
+ if strict:
+ raise ValueError('%s has host bits set' % self)
+ else:
+ self.network_address = IPv4Address(packed &
+ int(self.netmask))
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP prefix string.
+ addr = _split_optional_netmask(address)
+ self.network_address = IPv4Address(self._ip_int_from_string(addr[0]))
+
+ if len(addr) == 2:
+ arg = addr[1]
+ else:
+ arg = self._max_prefixlen
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+
+ if strict:
+ if (IPv4Address(int(self.network_address) & int(self.netmask)) !=
+ self.network_address):
+ raise ValueError('%s has host bits set' % self)
+ self.network_address = IPv4Address(int(self.network_address) &
+ int(self.netmask))
+
+ if self._prefixlen == (self._max_prefixlen - 1):
+ self.hosts = self.__iter__
+
+ @property
+ def is_global(self):
+ """Test if this address is allocated for public networks.
+
+ Returns:
+ A boolean, True if the address is not reserved per
+ iana-ipv4-special-registry.
+
+ """
+ return (not (self.network_address in IPv4Network('100.64.0.0/10') and
+ self.broadcast_address in IPv4Network('100.64.0.0/10')) and
+ not self.is_private)
+
+
+class _IPv4Constants(object):
+
+ _linklocal_network = IPv4Network('169.254.0.0/16')
+
+ _loopback_network = IPv4Network('127.0.0.0/8')
+
+ _multicast_network = IPv4Network('224.0.0.0/4')
+
+ _public_network = IPv4Network('100.64.0.0/10')
+
+ _private_networks = [
+ IPv4Network('0.0.0.0/8'),
+ IPv4Network('10.0.0.0/8'),
+ IPv4Network('127.0.0.0/8'),
+ IPv4Network('169.254.0.0/16'),
+ IPv4Network('172.16.0.0/12'),
+ IPv4Network('192.0.0.0/29'),
+ IPv4Network('192.0.0.170/31'),
+ IPv4Network('192.0.2.0/24'),
+ IPv4Network('192.168.0.0/16'),
+ IPv4Network('198.18.0.0/15'),
+ IPv4Network('198.51.100.0/24'),
+ IPv4Network('203.0.113.0/24'),
+ IPv4Network('240.0.0.0/4'),
+ IPv4Network('255.255.255.255/32'),
+ ]
+
+ _reserved_network = IPv4Network('240.0.0.0/4')
+
+ _unspecified_address = IPv4Address('0.0.0.0')
+
+
+IPv4Address._constants = _IPv4Constants
+
+
+class _BaseV6(object):
+
+ """Base IPv6 object.
+
+ The following methods are used by IPv6 objects in both single IP
+ addresses and networks.
+
+ """
+
+ __slots__ = ()
+ _version = 6
+ _ALL_ONES = (2 ** IPV6LENGTH) - 1
+ _HEXTET_COUNT = 8
+ _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef')
+ _max_prefixlen = IPV6LENGTH
+
+ # There are only a bunch of valid v6 netmasks, so we cache them all
+ # when constructed (see _make_netmask()).
+ _netmask_cache = {}
+
+ @classmethod
+ def _make_netmask(cls, arg):
+ """Make a (netmask, prefix_len) tuple from the given argument.
+
+ Argument can be:
+ - an integer (the prefix length)
+ - a string representing the prefix length (e.g. "24")
+ - a string representing the prefix netmask (e.g. "255.255.255.0")
+ """
+ if arg not in cls._netmask_cache:
+ if isinstance(arg, _compat_int_types):
+ prefixlen = arg
+ else:
+ prefixlen = cls._prefix_from_prefix_string(arg)
+ netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen))
+ cls._netmask_cache[arg] = netmask, prefixlen
+ return cls._netmask_cache[arg]
+
+ @classmethod
+ def _ip_int_from_string(cls, ip_str):
+ """Turn an IPv6 ip_str into an integer.
+
+ Args:
+ ip_str: A string, the IPv6 ip_str.
+
+ Returns:
+ An int, the IPv6 address
+
+ Raises:
+ AddressValueError: if ip_str isn't a valid IPv6 Address.
+
+ """
+ if not ip_str:
+ raise AddressValueError('Address cannot be empty')
+
+ parts = ip_str.split(':')
+
+ # An IPv6 address needs at least 2 colons (3 parts).
+ _min_parts = 3
+ if len(parts) < _min_parts:
+ msg = "At least %d parts expected in %r" % (_min_parts, ip_str)
+ raise AddressValueError(msg)
+
+ # If the address has an IPv4-style suffix, convert it to hexadecimal.
+ if '.' in parts[-1]:
+ try:
+ ipv4_int = IPv4Address(parts.pop())._ip
+ except AddressValueError as exc:
+ raise AddressValueError("%s in %r" % (exc, ip_str))
+ parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF))
+ parts.append('%x' % (ipv4_int & 0xFFFF))
+
+ # An IPv6 address can't have more than 8 colons (9 parts).
+ # The extra colon comes from using the "::" notation for a single
+ # leading or trailing zero part.
+ _max_parts = cls._HEXTET_COUNT + 1
+ if len(parts) > _max_parts:
+ msg = "At most %d colons permitted in %r" % (
+ _max_parts - 1, ip_str)
+ raise AddressValueError(msg)
+
+ # Disregarding the endpoints, find '::' with nothing in between.
+ # This indicates that a run of zeroes has been skipped.
+ skip_index = None
+ for i in _compat_range(1, len(parts) - 1):
+ if not parts[i]:
+ if skip_index is not None:
+ # Can't have more than one '::'
+ msg = "At most one '::' permitted in %r" % ip_str
+ raise AddressValueError(msg)
+ skip_index = i
+
+ # parts_hi is the number of parts to copy from above/before the '::'
+ # parts_lo is the number of parts to copy from below/after the '::'
+ if skip_index is not None:
+ # If we found a '::', then check if it also covers the endpoints.
+ parts_hi = skip_index
+ parts_lo = len(parts) - skip_index - 1
+ if not parts[0]:
+ parts_hi -= 1
+ if parts_hi:
+ msg = "Leading ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # ^: requires ^::
+ if not parts[-1]:
+ parts_lo -= 1
+ if parts_lo:
+ msg = "Trailing ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # :$ requires ::$
+ parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo)
+ if parts_skipped < 1:
+ msg = "Expected at most %d other parts with '::' in %r"
+ raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str))
+ else:
+ # Otherwise, allocate the entire address to parts_hi. The
+ # endpoints could still be empty, but _parse_hextet() will check
+ # for that.
+ if len(parts) != cls._HEXTET_COUNT:
+ msg = "Exactly %d parts expected without '::' in %r"
+ raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str))
+ if not parts[0]:
+ msg = "Leading ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # ^: requires ^::
+ if not parts[-1]:
+ msg = "Trailing ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # :$ requires ::$
+ parts_hi = len(parts)
+ parts_lo = 0
+ parts_skipped = 0
+
+ try:
+ # Now, parse the hextets into a 128-bit integer.
+ ip_int = 0
+ for i in range(parts_hi):
+ ip_int <<= 16
+ ip_int |= cls._parse_hextet(parts[i])
+ ip_int <<= 16 * parts_skipped
+ for i in range(-parts_lo, 0):
+ ip_int <<= 16
+ ip_int |= cls._parse_hextet(parts[i])
+ return ip_int
+ except ValueError as exc:
+ raise AddressValueError("%s in %r" % (exc, ip_str))
+
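+    # Worked example (illustrative): '2001:db8::1' splits into
+    # ['2001', 'db8', '', '1']; the '' at index 2 marks the '::', which
+    # expands to 8 - (2 + 1) = 5 zero hextets, giving the 128-bit integer
+    # 0x20010db8000000000000000000000001.
+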
+ @classmethod
+ def _parse_hextet(cls, hextet_str):
+ """Convert an IPv6 hextet string into an integer.
+
+ Args:
+ hextet_str: A string, the number to parse.
+
+ Returns:
+ The hextet as an integer.
+
+ Raises:
+ ValueError: if the input isn't strictly a hex number from
+ [0..FFFF].
+
+ """
+ # Whitelist the characters, since int() allows a lot of bizarre stuff.
+ if not cls._HEX_DIGITS.issuperset(hextet_str):
+ raise ValueError("Only hex digits permitted in %r" % hextet_str)
+ # We do the length check second, since the invalid character error
+ # is likely to be more informative for the user
+ if len(hextet_str) > 4:
+ msg = "At most 4 characters permitted in %r"
+ raise ValueError(msg % hextet_str)
+ # Length check means we can skip checking the integer value
+ return int(hextet_str, 16)
+
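+    # For example (illustrative): _parse_hextet('0db8') returns 0xdb8,
+    # while '12345' (too long) and 'g1' (non-hex digit) both raise
+    # ValueError.
+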
+ @classmethod
+ def _compress_hextets(cls, hextets):
+ """Compresses a list of hextets.
+
+ Compresses a list of strings, replacing the longest continuous
+ sequence of "0" in the list with "" and adding empty strings at
+ the beginning or at the end of the string such that subsequently
+ calling ":".join(hextets) will produce the compressed version of
+ the IPv6 address.
+
+ Args:
+ hextets: A list of strings, the hextets to compress.
+
+ Returns:
+ A list of strings.
+
+ """
+ best_doublecolon_start = -1
+ best_doublecolon_len = 0
+ doublecolon_start = -1
+ doublecolon_len = 0
+ for index, hextet in enumerate(hextets):
+ if hextet == '0':
+ doublecolon_len += 1
+ if doublecolon_start == -1:
+ # Start of a sequence of zeros.
+ doublecolon_start = index
+ if doublecolon_len > best_doublecolon_len:
+ # This is the longest sequence of zeros so far.
+ best_doublecolon_len = doublecolon_len
+ best_doublecolon_start = doublecolon_start
+ else:
+ doublecolon_len = 0
+ doublecolon_start = -1
+
+ if best_doublecolon_len > 1:
+ best_doublecolon_end = (best_doublecolon_start +
+ best_doublecolon_len)
+ # For zeros at the end of the address.
+ if best_doublecolon_end == len(hextets):
+ hextets += ['']
+ hextets[best_doublecolon_start:best_doublecolon_end] = ['']
+ # For zeros at the beginning of the address.
+ if best_doublecolon_start == 0:
+ hextets = [''] + hextets
+
+ return hextets
+
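+    # For example (illustrative): ['2001', 'db8', '0', '0', '0', '0', '0',
+    # '1'] compresses to ['2001', 'db8', '', '1'], so ':'.join() yields
+    # '2001:db8::1'.
+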
+ @classmethod
+ def _string_from_ip_int(cls, ip_int=None):
+ """Turns a 128-bit integer into hexadecimal notation.
+
+ Args:
+ ip_int: An integer, the IP address.
+
+ Returns:
+ A string, the hexadecimal representation of the address.
+
+ Raises:
+ ValueError: The address is bigger than 128 bits of all ones.
+
+ """
+ if ip_int is None:
+ ip_int = int(cls._ip)
+
+ if ip_int > cls._ALL_ONES:
+ raise ValueError('IPv6 address is too large')
+
+ hex_str = '%032x' % ip_int
+ hextets = ['%x' % int(hex_str[x:x + 4], 16) for x in range(0, 32, 4)]
+
+ hextets = cls._compress_hextets(hextets)
+ return ':'.join(hextets)
+
+ def _explode_shorthand_ip_string(self):
+ """Expand a shortened IPv6 address.
+
+ Returns:
+ A string, the expanded IPv6 address.
+
+ """
+ if isinstance(self, IPv6Network):
+ ip_str = _compat_str(self.network_address)
+ elif isinstance(self, IPv6Interface):
+ ip_str = _compat_str(self.ip)
+ else:
+ ip_str = _compat_str(self)
+
+ ip_int = self._ip_int_from_string(ip_str)
+ hex_str = '%032x' % ip_int
+ parts = [hex_str[x:x + 4] for x in range(0, 32, 4)]
+ if isinstance(self, (_BaseNetwork, IPv6Interface)):
+ return '%s/%d' % (':'.join(parts), self._prefixlen)
+ return ':'.join(parts)
+
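+    # For example (illustrative): IPv6Address('2001:db8::1').exploded
+    # returns '2001:0db8:0000:0000:0000:0000:0000:0001' via this helper
+    # (the exploded property is defined on the base class).
+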
+ def _reverse_pointer(self):
+ """Return the reverse DNS pointer name for the IPv6 address.
+
+ This implements the method described in RFC3596 2.5.
+
+ """
+ reverse_chars = self.exploded[::-1].replace(':', '')
+ return '.'.join(reverse_chars) + '.ip6.arpa'
+
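+    # For example (illustrative): IPv6Address('2001:db8::1') yields
+    # '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa'.
+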
+ @property
+ def max_prefixlen(self):
+ return self._max_prefixlen
+
+ @property
+ def version(self):
+ return self._version
+
+
+class IPv6Address(_BaseV6, _BaseAddress):
+
+ """Represent and manipulate single IPv6 Addresses."""
+
+ __slots__ = ('_ip', '__weakref__')
+
+ def __init__(self, address):
+ """Instantiate a new IPv6 address object.
+
+ Args:
+ address: A string or integer representing the IP
+
+ Additionally, an integer can be passed, so
+ IPv6Address('2001:db8::') ==
+ IPv6Address(42540766411282592856903984951653826560)
+ or, more generally
+ IPv6Address(int(IPv6Address('2001:db8::'))) ==
+ IPv6Address('2001:db8::')
+
+ Raises:
+ AddressValueError: If address isn't a valid IPv6 address.
+
+ """
+ # Efficient constructor from integer.
+ if isinstance(address, _compat_int_types):
+ self._check_int_address(address)
+ self._ip = address
+ return
+
+ # Constructing from a packed address
+ if isinstance(address, bytes):
+ self._check_packed_address(address, 16)
+ bvs = _compat_bytes_to_byte_vals(address)
+ self._ip = _compat_int_from_byte_vals(bvs, 'big')
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP string.
+ addr_str = _compat_str(address)
+ if '/' in addr_str:
+ raise AddressValueError("Unexpected '/' in %r" % address)
+ self._ip = self._ip_int_from_string(addr_str)
+
+ @property
+ def packed(self):
+ """The binary representation of this address."""
+ return v6_int_to_packed(self._ip)
+
+ @property
+ def is_multicast(self):
+ """Test if the address is reserved for multicast use.
+
+ Returns:
+ A boolean, True if the address is a multicast address.
+ See RFC 2373 2.7 for details.
+
+ """
+ return self in self._constants._multicast_network
+
+ @property
+ def is_reserved(self):
+ """Test if the address is otherwise IETF reserved.
+
+ Returns:
+ A boolean, True if the address is within one of the
+ reserved IPv6 Network ranges.
+
+ """
+ return any(self in x for x in self._constants._reserved_networks)
+
+ @property
+ def is_link_local(self):
+ """Test if the address is reserved for link-local.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 4291.
+
+ """
+ return self in self._constants._linklocal_network
+
+ @property
+ def is_site_local(self):
+ """Test if the address is reserved for site-local.
+
+ Note that the site-local address space has been deprecated by RFC 3879.
+ Use is_private to test if this address is in the space of unique local
+ addresses as defined by RFC 4193.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 3513 2.5.6.
+
+ """
+ return self in self._constants._sitelocal_network
+
+ @property
+ def is_private(self):
+ """Test if this address is allocated for private networks.
+
+ Returns:
+ A boolean, True if the address is reserved per
+ iana-ipv6-special-registry.
+
+ """
+ return any(self in net for net in self._constants._private_networks)
+
+ @property
+ def is_global(self):
+ """Test if this address is allocated for public networks.
+
+ Returns:
+ A boolean, true if the address is not reserved per
+ iana-ipv6-special-registry.
+
+ """
+ return not self.is_private
+
+ @property
+ def is_unspecified(self):
+ """Test if the address is unspecified.
+
+ Returns:
+ A boolean, True if this is the unspecified address as defined in
+ RFC 2373 2.5.2.
+
+ """
+ return self._ip == 0
+
+ @property
+ def is_loopback(self):
+ """Test if the address is a loopback address.
+
+ Returns:
+ A boolean, True if the address is a loopback address as defined in
+ RFC 2373 2.5.3.
+
+ """
+ return self._ip == 1
+
+ @property
+ def ipv4_mapped(self):
+ """Return the IPv4 mapped address.
+
+ Returns:
+ If the IPv6 address is a v4 mapped address, return the
+ IPv4 mapped address. Return None otherwise.
+
+ """
+ if (self._ip >> 32) != 0xFFFF:
+ return None
+ return IPv4Address(self._ip & 0xFFFFFFFF)
+
+ @property
+ def teredo(self):
+ """Tuple of embedded teredo IPs.
+
+ Returns:
+ Tuple of the (server, client) IPs or None if the address
+ doesn't appear to be a teredo address (doesn't start with
+ 2001::/32)
+
+ """
+ if (self._ip >> 96) != 0x20010000:
+ return None
+ return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF),
+ IPv4Address(~self._ip & 0xFFFFFFFF))
+
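+    # Worked example from RFC 4380 (illustrative): for
+    # IPv6Address('2001:0:4136:e378:8000:63bf:3fff:fdd2') this returns
+    # (IPv4Address('65.54.227.120'), IPv4Address('192.0.2.45')) -- the
+    # client bits are stored inverted, hence the ~ above.
+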
+ @property
+ def sixtofour(self):
+ """Return the IPv4 6to4 embedded address.
+
+ Returns:
+ The IPv4 6to4-embedded address if present or None if the
+ address doesn't appear to contain a 6to4 embedded address.
+
+ """
+ if (self._ip >> 112) != 0x2002:
+ return None
+ return IPv4Address((self._ip >> 80) & 0xFFFFFFFF)
+
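+    # For example (illustrative): IPv6Address('2002:c000:204::1').sixtofour
+    # returns IPv4Address('192.0.2.4').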
+
+class IPv6Interface(IPv6Address):
+
+ def __init__(self, address):
+ if isinstance(address, (bytes, _compat_int_types)):
+ IPv6Address.__init__(self, address)
+ self.network = IPv6Network(self._ip)
+ self._prefixlen = self._max_prefixlen
+ return
+ if isinstance(address, tuple):
+ IPv6Address.__init__(self, address[0])
+ if len(address) > 1:
+ self._prefixlen = int(address[1])
+ else:
+ self._prefixlen = self._max_prefixlen
+ self.network = IPv6Network(address, strict=False)
+ self.netmask = self.network.netmask
+ self.hostmask = self.network.hostmask
+ return
+
+ addr = _split_optional_netmask(address)
+ IPv6Address.__init__(self, addr[0])
+ self.network = IPv6Network(address, strict=False)
+ self.netmask = self.network.netmask
+ self._prefixlen = self.network._prefixlen
+ self.hostmask = self.network.hostmask
+
+ def __str__(self):
+ return '%s/%d' % (self._string_from_ip_int(self._ip),
+ self.network.prefixlen)
+
+ def __eq__(self, other):
+ address_equal = IPv6Address.__eq__(self, other)
+ if not address_equal or address_equal is NotImplemented:
+ return address_equal
+ try:
+ return self.network == other.network
+ except AttributeError:
+ # An interface with an associated network is NOT the
+ # same as an unassociated address. That's why the hash
+ # takes the extra info into account.
+ return False
+
+ def __lt__(self, other):
+ address_less = IPv6Address.__lt__(self, other)
+ if address_less is NotImplemented:
+ return NotImplemented
+ try:
+ return (self.network < other.network or
+ self.network == other.network and address_less)
+ except AttributeError:
+ # We *do* allow addresses and interfaces to be sorted. The
+ # unassociated address is considered less than all interfaces.
+ return False
+
+ def __hash__(self):
+ return self._ip ^ self._prefixlen ^ int(self.network.network_address)
+
+ __reduce__ = _IPAddressBase.__reduce__
+
+ @property
+ def ip(self):
+ return IPv6Address(self._ip)
+
+ @property
+ def with_prefixlen(self):
+ return '%s/%s' % (self._string_from_ip_int(self._ip),
+ self._prefixlen)
+
+ @property
+ def with_netmask(self):
+ return '%s/%s' % (self._string_from_ip_int(self._ip),
+ self.netmask)
+
+ @property
+ def with_hostmask(self):
+ return '%s/%s' % (self._string_from_ip_int(self._ip),
+ self.hostmask)
+
+ @property
+ def is_unspecified(self):
+ return self._ip == 0 and self.network.is_unspecified
+
+ @property
+ def is_loopback(self):
+ return self._ip == 1 and self.network.is_loopback
+
+
+class IPv6Network(_BaseV6, _BaseNetwork):
+
+ """This class represents and manipulates 128-bit IPv6 networks.
+
+    Attributes: [examples for IPv6Network('2001:db8::1000/124')]
+ .network_address: IPv6Address('2001:db8::1000')
+ .hostmask: IPv6Address('::f')
+ .broadcast_address: IPv6Address('2001:db8::100f')
+ .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0')
+ .prefixlen: 124
+
+ """
+
+ # Class to use when creating address objects
+ _address_class = IPv6Address
+
+ def __init__(self, address, strict=True):
+ """Instantiate a new IPv6 Network object.
+
+ Args:
+ address: A string or integer representing the IPv6 network or the
+ IP and prefix/netmask.
+ '2001:db8::/128'
+ '2001:db8:0000:0000:0000:0000:0000:0000/128'
+ '2001:db8::'
+ are all functionally the same in IPv6. That is to say,
+ failing to provide a subnetmask will create an object with
+ a mask of /128.
+
+ Additionally, an integer can be passed, so
+ IPv6Network('2001:db8::') ==
+ IPv6Network(42540766411282592856903984951653826560)
+ or, more generally
+ IPv6Network(int(IPv6Network('2001:db8::'))) ==
+ IPv6Network('2001:db8::')
+
+            strict: A boolean. If true, ensure that we have been passed
+              a true network address, e.g. 2001:db8::1000/124, and not an
+              IP address on a network, e.g. 2001:db8::1/124.
+
+ Raises:
+ AddressValueError: If address isn't a valid IPv6 address.
+ NetmaskValueError: If the netmask isn't valid for
+ an IPv6 address.
+ ValueError: If strict was True and a network address was not
+ supplied.
+
+ """
+ _BaseNetwork.__init__(self, address)
+
+ # Efficient constructor from integer or packed address
+ if isinstance(address, (bytes, _compat_int_types)):
+ self.network_address = IPv6Address(address)
+ self.netmask, self._prefixlen = self._make_netmask(
+ self._max_prefixlen)
+ return
+
+ if isinstance(address, tuple):
+ if len(address) > 1:
+ arg = address[1]
+ else:
+ arg = self._max_prefixlen
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+ self.network_address = IPv6Address(address[0])
+ packed = int(self.network_address)
+ if packed & int(self.netmask) != packed:
+ if strict:
+ raise ValueError('%s has host bits set' % self)
+ else:
+ self.network_address = IPv6Address(packed &
+ int(self.netmask))
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP prefix string.
+ addr = _split_optional_netmask(address)
+
+ self.network_address = IPv6Address(self._ip_int_from_string(addr[0]))
+
+ if len(addr) == 2:
+ arg = addr[1]
+ else:
+ arg = self._max_prefixlen
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+
+ if strict:
+ if (IPv6Address(int(self.network_address) & int(self.netmask)) !=
+ self.network_address):
+ raise ValueError('%s has host bits set' % self)
+ self.network_address = IPv6Address(int(self.network_address) &
+ int(self.netmask))
+
+ if self._prefixlen == (self._max_prefixlen - 1):
+ self.hosts = self.__iter__
+
+ def hosts(self):
+ """Generate Iterator over usable hosts in a network.
+
+ This is like __iter__ except it doesn't return the
+ Subnet-Router anycast address.
+
+ """
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in _compat_range(network + 1, broadcast + 1):
+ yield self._address_class(x)
+
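+    # For example (illustrative): IPv6Network('2001:db8::/126').hosts()
+    # yields 2001:db8::1 through 2001:db8::3, skipping the Subnet-Router
+    # anycast address 2001:db8::; for a /127, __init__ above rebinds
+    # hosts to __iter__ so that both addresses are produced.
+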
+ @property
+ def is_site_local(self):
+ """Test if the address is reserved for site-local.
+
+ Note that the site-local address space has been deprecated by RFC 3879.
+ Use is_private to test if this address is in the space of unique local
+ addresses as defined by RFC 4193.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 3513 2.5.6.
+
+ """
+ return (self.network_address.is_site_local and
+ self.broadcast_address.is_site_local)
+
+
+class _IPv6Constants(object):
+
+ _linklocal_network = IPv6Network('fe80::/10')
+
+ _multicast_network = IPv6Network('ff00::/8')
+
+ _private_networks = [
+ IPv6Network('::1/128'),
+ IPv6Network('::/128'),
+ IPv6Network('::ffff:0:0/96'),
+ IPv6Network('100::/64'),
+ IPv6Network('2001::/23'),
+ IPv6Network('2001:2::/48'),
+ IPv6Network('2001:db8::/32'),
+ IPv6Network('2001:10::/28'),
+ IPv6Network('fc00::/7'),
+ IPv6Network('fe80::/10'),
+ ]
+
+ _reserved_networks = [
+ IPv6Network('::/8'), IPv6Network('100::/8'),
+ IPv6Network('200::/7'), IPv6Network('400::/6'),
+ IPv6Network('800::/5'), IPv6Network('1000::/4'),
+ IPv6Network('4000::/3'), IPv6Network('6000::/3'),
+ IPv6Network('8000::/3'), IPv6Network('A000::/3'),
+ IPv6Network('C000::/3'), IPv6Network('E000::/4'),
+ IPv6Network('F000::/5'), IPv6Network('F800::/6'),
+ IPv6Network('FE00::/9'),
+ ]
+
+ _sitelocal_network = IPv6Network('fec0::/10')
+
+
+IPv6Address._constants = _IPv6Constants
diff --git a/test/support/integration/plugins/module_utils/crypto.py b/test/support/integration/plugins/module_utils/crypto.py
new file mode 100644
index 00000000..e67eeff1
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/crypto.py
@@ -0,0 +1,2125 @@
+# -*- coding: utf-8 -*-
+#
+# (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+# ----------------------------------------------------------------------
+# A clearly marked portion of this file is licensed under the BSD license
+# Copyright (c) 2015, 2016 Paul Kehrer (@reaperhulk)
+# Copyright (c) 2017 Fraser Tweedale (@frasertweedale)
+# For more details, search for the function _obj2txt().
+# ---------------------------------------------------------------------
+# A clearly marked portion of this file is extracted from a project that
+# is licensed under the Apache License 2.0
+# Copyright (c) the OpenSSL contributors
+# For more details, search for the function _OID_MAP.
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+import sys
+from distutils.version import LooseVersion
+
+try:
+ import OpenSSL
+ from OpenSSL import crypto
+except ImportError:
+ # An error will be raised in the calling class to let the end
+ # user know that OpenSSL couldn't be found.
+ pass
+
+try:
+ import cryptography
+ from cryptography import x509
+ from cryptography.hazmat.backends import default_backend as cryptography_backend
+ from cryptography.hazmat.primitives.serialization import load_pem_private_key
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.hazmat.primitives import serialization
+ import ipaddress
+
+ # Older versions of cryptography (< 2.1) do not have __hash__ functions for
+ # general name objects (DNSName, IPAddress, ...), while providing overloaded
+ # equality and string representation operations. This makes it impossible to
+ # use them in hash-based data structures such as set or dict. Since we are
+ # actually doing that in openssl_certificate, and potentially in other code,
+ # we need to monkey-patch __hash__ for these classes to make sure our code
+ # works fine.
+ if LooseVersion(cryptography.__version__) < LooseVersion('2.1'):
+        # A very simple hash function which relies on the representation
+ # of an object to be implemented. This is the case since at least
+ # cryptography 1.0, see
+ # https://github.com/pyca/cryptography/commit/7a9abce4bff36c05d26d8d2680303a6f64a0e84f
+ def simple_hash(self):
+ return hash(repr(self))
+
+ # The hash functions for the following types were added for cryptography 2.1:
+ # https://github.com/pyca/cryptography/commit/fbfc36da2a4769045f2373b004ddf0aff906cf38
+ x509.DNSName.__hash__ = simple_hash
+ x509.DirectoryName.__hash__ = simple_hash
+ x509.GeneralName.__hash__ = simple_hash
+ x509.IPAddress.__hash__ = simple_hash
+ x509.OtherName.__hash__ = simple_hash
+ x509.RegisteredID.__hash__ = simple_hash
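+        # Illustrative consequence: on Python 3, defining __eq__ without
+        # __hash__ makes these classes unhashable, so without the patch
+        # e.g. set([x509.DNSName(u'example.com')]) raises TypeError.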
+
+ if LooseVersion(cryptography.__version__) < LooseVersion('1.2'):
+ # The hash functions for the following types were added for cryptography 1.2:
+ # https://github.com/pyca/cryptography/commit/b642deed88a8696e5f01ce6855ccf89985fc35d0
+ # https://github.com/pyca/cryptography/commit/d1b5681f6db2bde7a14625538bd7907b08dfb486
+ x509.RFC822Name.__hash__ = simple_hash
+ x509.UniformResourceIdentifier.__hash__ = simple_hash
+
+ # Test whether we have support for X25519, X448, Ed25519 and/or Ed448
+ try:
+ import cryptography.hazmat.primitives.asymmetric.x25519
+ CRYPTOGRAPHY_HAS_X25519 = True
+ try:
+ cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.private_bytes
+ CRYPTOGRAPHY_HAS_X25519_FULL = True
+ except AttributeError:
+ CRYPTOGRAPHY_HAS_X25519_FULL = False
+ except ImportError:
+ CRYPTOGRAPHY_HAS_X25519 = False
+ CRYPTOGRAPHY_HAS_X25519_FULL = False
+ try:
+ import cryptography.hazmat.primitives.asymmetric.x448
+ CRYPTOGRAPHY_HAS_X448 = True
+ except ImportError:
+ CRYPTOGRAPHY_HAS_X448 = False
+ try:
+ import cryptography.hazmat.primitives.asymmetric.ed25519
+ CRYPTOGRAPHY_HAS_ED25519 = True
+ except ImportError:
+ CRYPTOGRAPHY_HAS_ED25519 = False
+ try:
+ import cryptography.hazmat.primitives.asymmetric.ed448
+ CRYPTOGRAPHY_HAS_ED448 = True
+ except ImportError:
+ CRYPTOGRAPHY_HAS_ED448 = False
+
+ HAS_CRYPTOGRAPHY = True
+except ImportError:
+ # Error handled in the calling module.
+ CRYPTOGRAPHY_HAS_X25519 = False
+ CRYPTOGRAPHY_HAS_X25519_FULL = False
+ CRYPTOGRAPHY_HAS_X448 = False
+ CRYPTOGRAPHY_HAS_ED25519 = False
+ CRYPTOGRAPHY_HAS_ED448 = False
+ HAS_CRYPTOGRAPHY = False
+
+
+import abc
+import base64
+import binascii
+import datetime
+import errno
+import hashlib
+import os
+import re
+import tempfile
+
+from ansible.module_utils import six
+from ansible.module_utils._text import to_native, to_bytes, to_text
+
+
+class OpenSSLObjectError(Exception):
+ pass
+
+
+class OpenSSLBadPassphraseError(OpenSSLObjectError):
+ pass
+
+
+def get_fingerprint_of_bytes(source):
+ """Generate the fingerprint of the given bytes."""
+
+ fingerprint = {}
+
+ try:
+ algorithms = hashlib.algorithms
+ except AttributeError:
+ try:
+ algorithms = hashlib.algorithms_guaranteed
+ except AttributeError:
+ return None
+
+ for algo in algorithms:
+ f = getattr(hashlib, algo)
+ try:
+ h = f(source)
+ except ValueError:
+ # This can happen for hash algorithms not supported in FIPS mode
+ # (https://github.com/ansible/ansible/issues/67213)
+ continue
+ try:
+ # Certain hash functions have a hexdigest() which expects a length parameter
+ pubkey_digest = h.hexdigest()
+ except TypeError:
+ pubkey_digest = h.hexdigest(32)
+ fingerprint[algo] = ':'.join(pubkey_digest[i:i + 2] for i in range(0, len(pubkey_digest), 2))
+
+ return fingerprint
+
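+# For example (illustrative): get_fingerprint_of_bytes(b'abc')['sha256']
+# starts with 'ba:78:16:bf:...' -- one colon-separated hex digest per
+# available algorithm.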
+
+def get_fingerprint(path, passphrase=None, content=None, backend='pyopenssl'):
+ """Generate the fingerprint of the public key. """
+
+ privatekey = load_privatekey(path, passphrase=passphrase, content=content, check_passphrase=False, backend=backend)
+
+ if backend == 'pyopenssl':
+ try:
+ publickey = crypto.dump_publickey(crypto.FILETYPE_ASN1, privatekey)
+ except AttributeError:
+ # If PyOpenSSL < 16.0 crypto.dump_publickey() will fail.
+ try:
+ bio = crypto._new_mem_buf()
+ rc = crypto._lib.i2d_PUBKEY_bio(bio, privatekey._pkey)
+ if rc != 1:
+ crypto._raise_current_error()
+ publickey = crypto._bio_to_string(bio)
+ except AttributeError:
+                # Returning None here avoids raising an error; the caller
+                # simply gets no fingerprint value.
+ return None
+ elif backend == 'cryptography':
+ publickey = privatekey.public_key().public_bytes(
+ serialization.Encoding.DER,
+ serialization.PublicFormat.SubjectPublicKeyInfo
+ )
+
+ return get_fingerprint_of_bytes(publickey)
+
+
+def load_file_if_exists(path, module=None, ignore_errors=False):
+ try:
+ with open(path, 'rb') as f:
+ return f.read()
+ except EnvironmentError as exc:
+ if exc.errno == errno.ENOENT:
+ return None
+ if ignore_errors:
+ return None
+ if module is None:
+ raise
+ module.fail_json('Error while loading {0} - {1}'.format(path, str(exc)))
+ except Exception as exc:
+ if ignore_errors:
+ return None
+ if module is None:
+ raise
+ module.fail_json('Error while loading {0} - {1}'.format(path, str(exc)))
+
+
+def load_privatekey(path, passphrase=None, check_passphrase=True, content=None, backend='pyopenssl'):
+ """Load the specified OpenSSL private key.
+
+ The content can also be specified via content; in that case,
+ this function will not load the key from disk.
+ """
+
+ try:
+ if content is None:
+ with open(path, 'rb') as b_priv_key_fh:
+ priv_key_detail = b_priv_key_fh.read()
+ else:
+ priv_key_detail = content
+
+ if backend == 'pyopenssl':
+
+            # First try: load with the given passphrase (or the empty string
+            # if none was given). This works if the passphrase is correct,
+            # or if the key is not password-protected.
+ try:
+ result = crypto.load_privatekey(crypto.FILETYPE_PEM,
+ priv_key_detail,
+ to_bytes(passphrase or ''))
+ except crypto.Error as e:
+ if len(e.args) > 0 and len(e.args[0]) > 0:
+ if e.args[0][0][2] in ('bad decrypt', 'bad password read'):
+ # This happens in case we have the wrong passphrase.
+ if passphrase is not None:
+ raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key!')
+ else:
+ raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!')
+ raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e))
+ if check_passphrase:
+ # Next we want to make sure that the key is actually protected by
+ # a passphrase (in case we did try the empty string before, make
+ # sure that the key is not protected by the empty string)
+ try:
+ crypto.load_privatekey(crypto.FILETYPE_PEM,
+ priv_key_detail,
+ to_bytes('y' if passphrase == 'x' else 'x'))
+ if passphrase is not None:
+ # Since we can load the key without an exception, the
+ # key isn't password-protected
+ raise OpenSSLBadPassphraseError('Passphrase provided, but private key is not password-protected!')
+ except crypto.Error as e:
+ if passphrase is None and len(e.args) > 0 and len(e.args[0]) > 0:
+ if e.args[0][0][2] in ('bad decrypt', 'bad password read'):
+ # The key is obviously protected by the empty string.
+ # Don't do this at home (if it's possible at all)...
+ raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!')
+ elif backend == 'cryptography':
+ try:
+ result = load_pem_private_key(priv_key_detail,
+ None if passphrase is None else to_bytes(passphrase),
+ cryptography_backend())
+ except TypeError as dummy:
+ raise OpenSSLBadPassphraseError('Wrong or empty passphrase provided for private key')
+ except ValueError as dummy:
+ raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key')
+
+ return result
+ except (IOError, OSError) as exc:
+ raise OpenSSLObjectError(exc)
+
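+# Typical call site (sketch; the path and passphrase shown here are
+# hypothetical, caller-supplied values):
+#   key = load_privatekey('/tmp/example.key', passphrase='secret',
+#                         backend='cryptography')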
+
+def load_certificate(path, content=None, backend='pyopenssl'):
+ """Load the specified certificate."""
+
+ try:
+ if content is None:
+ with open(path, 'rb') as cert_fh:
+ cert_content = cert_fh.read()
+ else:
+ cert_content = content
+ if backend == 'pyopenssl':
+ return crypto.load_certificate(crypto.FILETYPE_PEM, cert_content)
+ elif backend == 'cryptography':
+ return x509.load_pem_x509_certificate(cert_content, cryptography_backend())
+ except (IOError, OSError) as exc:
+ raise OpenSSLObjectError(exc)
+
+
+def load_certificate_request(path, content=None, backend='pyopenssl'):
+ """Load the specified certificate signing request."""
+ try:
+ if content is None:
+ with open(path, 'rb') as csr_fh:
+ csr_content = csr_fh.read()
+ else:
+ csr_content = content
+ except (IOError, OSError) as exc:
+ raise OpenSSLObjectError(exc)
+ if backend == 'pyopenssl':
+ return crypto.load_certificate_request(crypto.FILETYPE_PEM, csr_content)
+ elif backend == 'cryptography':
+ return x509.load_pem_x509_csr(csr_content, cryptography_backend())
+
+
+def parse_name_field(input_dict):
+ """Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
+
+ result = []
+ for key in input_dict:
+ if isinstance(input_dict[key], list):
+ for entry in input_dict[key]:
+ result.append((key, entry))
+ else:
+ result.append((key, input_dict[key]))
+ return result
+
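+# For example (illustrative): parse_name_field({'C': 'US', 'OU': ['a', 'b']})
+# returns [('C', 'US'), ('OU', 'a'), ('OU', 'b')] (dict ordering permitting).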
+
+def convert_relative_to_datetime(relative_time_string):
+ """Get a datetime.datetime or None from a string in the time format described in sshd_config(5)"""
+
+ parsed_result = re.match(
+ r"^(?P<prefix>[+-])((?P<weeks>\d+)[wW])?((?P<days>\d+)[dD])?((?P<hours>\d+)[hH])?((?P<minutes>\d+)[mM])?((?P<seconds>\d+)[sS]?)?$",
+ relative_time_string)
+
+ if parsed_result is None or len(relative_time_string) == 1:
+ # not matched or only a single "+" or "-"
+ return None
+
+ offset = datetime.timedelta(0)
+ if parsed_result.group("weeks") is not None:
+ offset += datetime.timedelta(weeks=int(parsed_result.group("weeks")))
+ if parsed_result.group("days") is not None:
+ offset += datetime.timedelta(days=int(parsed_result.group("days")))
+ if parsed_result.group("hours") is not None:
+ offset += datetime.timedelta(hours=int(parsed_result.group("hours")))
+ if parsed_result.group("minutes") is not None:
+ offset += datetime.timedelta(
+ minutes=int(parsed_result.group("minutes")))
+ if parsed_result.group("seconds") is not None:
+ offset += datetime.timedelta(
+ seconds=int(parsed_result.group("seconds")))
+
+ if parsed_result.group("prefix") == "+":
+ return datetime.datetime.utcnow() + offset
+ else:
+ return datetime.datetime.utcnow() - offset
+
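+# For example (illustrative): convert_relative_to_datetime('+32w1d2h')
+# returns utcnow() plus 32 weeks, 1 day and 2 hours; '-2s' returns a point
+# two seconds in the past; a bare '+' or an unparsable string returns None.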
+
+def get_relative_time_option(input_string, input_name, backend='cryptography'):
+ """Return an absolute timespec if a relative timespec or an ASN1 formatted
+ string is provided.
+
+ The return value will be a datetime object for the cryptography backend,
+ and a ASN1 formatted string for the pyopenssl backend."""
+ result = to_native(input_string)
+ if result is None:
+        raise OpenSSLObjectError(
+            'The timespec "%s" for %s is not valid' %
+            (input_string, input_name))
+ # Relative time
+ if result.startswith("+") or result.startswith("-"):
+ result_datetime = convert_relative_to_datetime(result)
+ if backend == 'pyopenssl':
+ return result_datetime.strftime("%Y%m%d%H%M%SZ")
+ elif backend == 'cryptography':
+ return result_datetime
+ # Absolute time
+ if backend == 'pyopenssl':
+ return input_string
+ elif backend == 'cryptography':
+ for date_fmt in ['%Y%m%d%H%M%SZ', '%Y%m%d%H%MZ', '%Y%m%d%H%M%S%z', '%Y%m%d%H%M%z']:
+ try:
+ return datetime.datetime.strptime(result, date_fmt)
+ except ValueError:
+ pass
+
+ raise OpenSSLObjectError(
+ 'The time spec "%s" for %s is invalid' %
+ (input_string, input_name)
+ )
+
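+# For example (illustrative; 'valid_to' is a hypothetical option name):
+# get_relative_time_option('+365d', 'valid_to') returns a datetime 365 days
+# in the future on the cryptography backend, and the corresponding
+# 'YYYYMMDDHHMMSSZ' string on the pyopenssl backend.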
+
+def select_message_digest(digest_string):
+ digest = None
+ if digest_string == 'sha256':
+ digest = hashes.SHA256()
+ elif digest_string == 'sha384':
+ digest = hashes.SHA384()
+ elif digest_string == 'sha512':
+ digest = hashes.SHA512()
+ elif digest_string == 'sha1':
+ digest = hashes.SHA1()
+ elif digest_string == 'md5':
+ digest = hashes.MD5()
+ return digest
+
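+# For example (illustrative): select_message_digest('sha256') returns a
+# cryptography hashes.SHA256() instance; unknown names return None.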
+
+def write_file(module, content, default_mode=None, path=None):
+ '''
+ Writes content into destination file as securely as possible.
+ Uses file arguments from module.
+ '''
+ # Find out parameters for file
+ file_args = module.load_file_common_arguments(module.params, path=path)
+ if file_args['mode'] is None:
+ file_args['mode'] = default_mode
+ # Create tempfile name
+ tmp_fd, tmp_name = tempfile.mkstemp(prefix=b'.ansible_tmp')
+ try:
+ os.close(tmp_fd)
+ except Exception as dummy:
+ pass
+ module.add_cleanup_file(tmp_name) # if we fail, let Ansible try to remove the file
+ try:
+ try:
+ # Create tempfile
+ file = os.open(tmp_name, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
+ os.write(file, content)
+ os.close(file)
+ except Exception as e:
+ try:
+ os.remove(tmp_name)
+ except Exception as dummy:
+ pass
+ module.fail_json(msg='Error while writing result into temporary file: {0}'.format(e))
+ # Update destination to wanted permissions
+ if os.path.exists(file_args['path']):
+ module.set_fs_attributes_if_different(file_args, False)
+ # Move tempfile to final destination
+ module.atomic_move(tmp_name, file_args['path'])
+ # Try to update permissions again
+ module.set_fs_attributes_if_different(file_args, False)
+ except Exception as e:
+ try:
+ os.remove(tmp_name)
+ except Exception as dummy:
+ pass
+ module.fail_json(msg='Error while writing result: {0}'.format(e))
+
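+# Typical call site (sketch; 'module' is an AnsibleModule instance and
+# 'pem_data' the bytes to persist):
+#   write_file(module, pem_data, default_mode=0o600)
+# The content is staged in a temporary file and atomic_move()d into place.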
+
+@six.add_metaclass(abc.ABCMeta)
+class OpenSSLObject(object):
+
+ def __init__(self, path, state, force, check_mode):
+ self.path = path
+ self.state = state
+ self.force = force
+ self.name = os.path.basename(path)
+ self.changed = False
+ self.check_mode = check_mode
+
+ def check(self, module, perms_required=True):
+ """Ensure the resource is in its desired state."""
+
+ def _check_state():
+ return os.path.exists(self.path)
+
+ def _check_perms(module):
+ file_args = module.load_file_common_arguments(module.params)
+ return not module.set_fs_attributes_if_different(file_args, False)
+
+ if not perms_required:
+ return _check_state()
+
+ return _check_state() and _check_perms(module)
+
+ @abc.abstractmethod
+ def dump(self):
+ """Serialize the object into a dictionary."""
+
+ pass
+
+ @abc.abstractmethod
+ def generate(self):
+ """Generate the resource."""
+
+ pass
+
+ def remove(self, module):
+ """Remove the resource from the filesystem."""
+
+ try:
+ os.remove(self.path)
+ self.changed = True
+ except OSError as exc:
+ if exc.errno != errno.ENOENT:
+ raise OpenSSLObjectError(exc)
+ else:
+ pass
+
+
+# #####################################################################################
+# #####################################################################################
+# This has been extracted from the OpenSSL project's objects.txt:
+# https://github.com/openssl/openssl/blob/9537fe5757bb07761fa275d779bbd40bcf5530e4/crypto/objects/objects.txt
+# Extracted with https://gist.github.com/felixfontein/376748017ad65ead093d56a45a5bf376
+#
+# In case the following data structure has any copyrightable content, note that it is licensed as follows:
+# Copyright (c) the OpenSSL contributors
+# Licensed under the Apache License 2.0
+# https://github.com/openssl/openssl/blob/master/LICENSE
+_OID_MAP = {
+ '0': ('itu-t', 'ITU-T', 'ccitt'),
+ '0.3.4401.5': ('ntt-ds', ),
+ '0.3.4401.5.3.1.9': ('camellia', ),
+ '0.3.4401.5.3.1.9.1': ('camellia-128-ecb', 'CAMELLIA-128-ECB'),
+ '0.3.4401.5.3.1.9.3': ('camellia-128-ofb', 'CAMELLIA-128-OFB'),
+ '0.3.4401.5.3.1.9.4': ('camellia-128-cfb', 'CAMELLIA-128-CFB'),
+ '0.3.4401.5.3.1.9.6': ('camellia-128-gcm', 'CAMELLIA-128-GCM'),
+ '0.3.4401.5.3.1.9.7': ('camellia-128-ccm', 'CAMELLIA-128-CCM'),
+ '0.3.4401.5.3.1.9.9': ('camellia-128-ctr', 'CAMELLIA-128-CTR'),
+ '0.3.4401.5.3.1.9.10': ('camellia-128-cmac', 'CAMELLIA-128-CMAC'),
+ '0.3.4401.5.3.1.9.21': ('camellia-192-ecb', 'CAMELLIA-192-ECB'),
+ '0.3.4401.5.3.1.9.23': ('camellia-192-ofb', 'CAMELLIA-192-OFB'),
+ '0.3.4401.5.3.1.9.24': ('camellia-192-cfb', 'CAMELLIA-192-CFB'),
+ '0.3.4401.5.3.1.9.26': ('camellia-192-gcm', 'CAMELLIA-192-GCM'),
+ '0.3.4401.5.3.1.9.27': ('camellia-192-ccm', 'CAMELLIA-192-CCM'),
+ '0.3.4401.5.3.1.9.29': ('camellia-192-ctr', 'CAMELLIA-192-CTR'),
+ '0.3.4401.5.3.1.9.30': ('camellia-192-cmac', 'CAMELLIA-192-CMAC'),
+ '0.3.4401.5.3.1.9.41': ('camellia-256-ecb', 'CAMELLIA-256-ECB'),
+ '0.3.4401.5.3.1.9.43': ('camellia-256-ofb', 'CAMELLIA-256-OFB'),
+ '0.3.4401.5.3.1.9.44': ('camellia-256-cfb', 'CAMELLIA-256-CFB'),
+ '0.3.4401.5.3.1.9.46': ('camellia-256-gcm', 'CAMELLIA-256-GCM'),
+ '0.3.4401.5.3.1.9.47': ('camellia-256-ccm', 'CAMELLIA-256-CCM'),
+ '0.3.4401.5.3.1.9.49': ('camellia-256-ctr', 'CAMELLIA-256-CTR'),
+ '0.3.4401.5.3.1.9.50': ('camellia-256-cmac', 'CAMELLIA-256-CMAC'),
+ '0.9': ('data', ),
+ '0.9.2342': ('pss', ),
+ '0.9.2342.19200300': ('ucl', ),
+ '0.9.2342.19200300.100': ('pilot', ),
+ '0.9.2342.19200300.100.1': ('pilotAttributeType', ),
+ '0.9.2342.19200300.100.1.1': ('userId', 'UID'),
+ '0.9.2342.19200300.100.1.2': ('textEncodedORAddress', ),
+ '0.9.2342.19200300.100.1.3': ('rfc822Mailbox', 'mail'),
+ '0.9.2342.19200300.100.1.4': ('info', ),
+ '0.9.2342.19200300.100.1.5': ('favouriteDrink', ),
+ '0.9.2342.19200300.100.1.6': ('roomNumber', ),
+ '0.9.2342.19200300.100.1.7': ('photo', ),
+ '0.9.2342.19200300.100.1.8': ('userClass', ),
+ '0.9.2342.19200300.100.1.9': ('host', ),
+ '0.9.2342.19200300.100.1.10': ('manager', ),
+ '0.9.2342.19200300.100.1.11': ('documentIdentifier', ),
+ '0.9.2342.19200300.100.1.12': ('documentTitle', ),
+ '0.9.2342.19200300.100.1.13': ('documentVersion', ),
+ '0.9.2342.19200300.100.1.14': ('documentAuthor', ),
+ '0.9.2342.19200300.100.1.15': ('documentLocation', ),
+ '0.9.2342.19200300.100.1.20': ('homeTelephoneNumber', ),
+ '0.9.2342.19200300.100.1.21': ('secretary', ),
+ '0.9.2342.19200300.100.1.22': ('otherMailbox', ),
+ '0.9.2342.19200300.100.1.23': ('lastModifiedTime', ),
+ '0.9.2342.19200300.100.1.24': ('lastModifiedBy', ),
+ '0.9.2342.19200300.100.1.25': ('domainComponent', 'DC'),
+ '0.9.2342.19200300.100.1.26': ('aRecord', ),
+ '0.9.2342.19200300.100.1.27': ('pilotAttributeType27', ),
+ '0.9.2342.19200300.100.1.28': ('mXRecord', ),
+ '0.9.2342.19200300.100.1.29': ('nSRecord', ),
+ '0.9.2342.19200300.100.1.30': ('sOARecord', ),
+ '0.9.2342.19200300.100.1.31': ('cNAMERecord', ),
+ '0.9.2342.19200300.100.1.37': ('associatedDomain', ),
+ '0.9.2342.19200300.100.1.38': ('associatedName', ),
+ '0.9.2342.19200300.100.1.39': ('homePostalAddress', ),
+ '0.9.2342.19200300.100.1.40': ('personalTitle', ),
+ '0.9.2342.19200300.100.1.41': ('mobileTelephoneNumber', ),
+ '0.9.2342.19200300.100.1.42': ('pagerTelephoneNumber', ),
+ '0.9.2342.19200300.100.1.43': ('friendlyCountryName', ),
+ '0.9.2342.19200300.100.1.44': ('uniqueIdentifier', 'uid'),
+ '0.9.2342.19200300.100.1.45': ('organizationalStatus', ),
+ '0.9.2342.19200300.100.1.46': ('janetMailbox', ),
+ '0.9.2342.19200300.100.1.47': ('mailPreferenceOption', ),
+ '0.9.2342.19200300.100.1.48': ('buildingName', ),
+ '0.9.2342.19200300.100.1.49': ('dSAQuality', ),
+ '0.9.2342.19200300.100.1.50': ('singleLevelQuality', ),
+ '0.9.2342.19200300.100.1.51': ('subtreeMinimumQuality', ),
+ '0.9.2342.19200300.100.1.52': ('subtreeMaximumQuality', ),
+ '0.9.2342.19200300.100.1.53': ('personalSignature', ),
+ '0.9.2342.19200300.100.1.54': ('dITRedirect', ),
+ '0.9.2342.19200300.100.1.55': ('audio', ),
+ '0.9.2342.19200300.100.1.56': ('documentPublisher', ),
+ '0.9.2342.19200300.100.3': ('pilotAttributeSyntax', ),
+ '0.9.2342.19200300.100.3.4': ('iA5StringSyntax', ),
+ '0.9.2342.19200300.100.3.5': ('caseIgnoreIA5StringSyntax', ),
+ '0.9.2342.19200300.100.4': ('pilotObjectClass', ),
+ '0.9.2342.19200300.100.4.3': ('pilotObject', ),
+ '0.9.2342.19200300.100.4.4': ('pilotPerson', ),
+ '0.9.2342.19200300.100.4.5': ('account', ),
+ '0.9.2342.19200300.100.4.6': ('document', ),
+ '0.9.2342.19200300.100.4.7': ('room', ),
+ '0.9.2342.19200300.100.4.9': ('documentSeries', ),
+ '0.9.2342.19200300.100.4.13': ('Domain', 'domain'),
+ '0.9.2342.19200300.100.4.14': ('rFC822localPart', ),
+ '0.9.2342.19200300.100.4.15': ('dNSDomain', ),
+ '0.9.2342.19200300.100.4.17': ('domainRelatedObject', ),
+ '0.9.2342.19200300.100.4.18': ('friendlyCountry', ),
+ '0.9.2342.19200300.100.4.19': ('simpleSecurityObject', ),
+ '0.9.2342.19200300.100.4.20': ('pilotOrganization', ),
+ '0.9.2342.19200300.100.4.21': ('pilotDSA', ),
+ '0.9.2342.19200300.100.4.22': ('qualityLabelledData', ),
+ '0.9.2342.19200300.100.10': ('pilotGroups', ),
+ '1': ('iso', 'ISO'),
+ '1.0.9797.3.4': ('gmac', 'GMAC'),
+ '1.0.10118.3.0.55': ('whirlpool', ),
+ '1.2': ('ISO Member Body', 'member-body'),
+ '1.2.156': ('ISO CN Member Body', 'ISO-CN'),
+ '1.2.156.10197': ('oscca', ),
+ '1.2.156.10197.1': ('sm-scheme', ),
+ '1.2.156.10197.1.104.1': ('sm4-ecb', 'SM4-ECB'),
+ '1.2.156.10197.1.104.2': ('sm4-cbc', 'SM4-CBC'),
+ '1.2.156.10197.1.104.3': ('sm4-ofb', 'SM4-OFB'),
+ '1.2.156.10197.1.104.4': ('sm4-cfb', 'SM4-CFB'),
+ '1.2.156.10197.1.104.5': ('sm4-cfb1', 'SM4-CFB1'),
+ '1.2.156.10197.1.104.6': ('sm4-cfb8', 'SM4-CFB8'),
+ '1.2.156.10197.1.104.7': ('sm4-ctr', 'SM4-CTR'),
+ '1.2.156.10197.1.301': ('sm2', 'SM2'),
+ '1.2.156.10197.1.401': ('sm3', 'SM3'),
+ '1.2.156.10197.1.501': ('SM2-with-SM3', 'SM2-SM3'),
+ '1.2.156.10197.1.504': ('sm3WithRSAEncryption', 'RSA-SM3'),
+ '1.2.392.200011.61.1.1.1.2': ('camellia-128-cbc', 'CAMELLIA-128-CBC'),
+ '1.2.392.200011.61.1.1.1.3': ('camellia-192-cbc', 'CAMELLIA-192-CBC'),
+ '1.2.392.200011.61.1.1.1.4': ('camellia-256-cbc', 'CAMELLIA-256-CBC'),
+ '1.2.392.200011.61.1.1.3.2': ('id-camellia128-wrap', ),
+ '1.2.392.200011.61.1.1.3.3': ('id-camellia192-wrap', ),
+ '1.2.392.200011.61.1.1.3.4': ('id-camellia256-wrap', ),
+ '1.2.410.200004': ('kisa', 'KISA'),
+ '1.2.410.200004.1.3': ('seed-ecb', 'SEED-ECB'),
+ '1.2.410.200004.1.4': ('seed-cbc', 'SEED-CBC'),
+ '1.2.410.200004.1.5': ('seed-cfb', 'SEED-CFB'),
+ '1.2.410.200004.1.6': ('seed-ofb', 'SEED-OFB'),
+ '1.2.410.200046.1.1': ('aria', ),
+ '1.2.410.200046.1.1.1': ('aria-128-ecb', 'ARIA-128-ECB'),
+ '1.2.410.200046.1.1.2': ('aria-128-cbc', 'ARIA-128-CBC'),
+ '1.2.410.200046.1.1.3': ('aria-128-cfb', 'ARIA-128-CFB'),
+ '1.2.410.200046.1.1.4': ('aria-128-ofb', 'ARIA-128-OFB'),
+ '1.2.410.200046.1.1.5': ('aria-128-ctr', 'ARIA-128-CTR'),
+ '1.2.410.200046.1.1.6': ('aria-192-ecb', 'ARIA-192-ECB'),
+ '1.2.410.200046.1.1.7': ('aria-192-cbc', 'ARIA-192-CBC'),
+ '1.2.410.200046.1.1.8': ('aria-192-cfb', 'ARIA-192-CFB'),
+ '1.2.410.200046.1.1.9': ('aria-192-ofb', 'ARIA-192-OFB'),
+ '1.2.410.200046.1.1.10': ('aria-192-ctr', 'ARIA-192-CTR'),
+ '1.2.410.200046.1.1.11': ('aria-256-ecb', 'ARIA-256-ECB'),
+ '1.2.410.200046.1.1.12': ('aria-256-cbc', 'ARIA-256-CBC'),
+ '1.2.410.200046.1.1.13': ('aria-256-cfb', 'ARIA-256-CFB'),
+ '1.2.410.200046.1.1.14': ('aria-256-ofb', 'ARIA-256-OFB'),
+ '1.2.410.200046.1.1.15': ('aria-256-ctr', 'ARIA-256-CTR'),
+ '1.2.410.200046.1.1.34': ('aria-128-gcm', 'ARIA-128-GCM'),
+ '1.2.410.200046.1.1.35': ('aria-192-gcm', 'ARIA-192-GCM'),
+ '1.2.410.200046.1.1.36': ('aria-256-gcm', 'ARIA-256-GCM'),
+ '1.2.410.200046.1.1.37': ('aria-128-ccm', 'ARIA-128-CCM'),
+ '1.2.410.200046.1.1.38': ('aria-192-ccm', 'ARIA-192-CCM'),
+ '1.2.410.200046.1.1.39': ('aria-256-ccm', 'ARIA-256-CCM'),
+ '1.2.643.2.2': ('cryptopro', ),
+ '1.2.643.2.2.3': ('GOST R 34.11-94 with GOST R 34.10-2001', 'id-GostR3411-94-with-GostR3410-2001'),
+ '1.2.643.2.2.4': ('GOST R 34.11-94 with GOST R 34.10-94', 'id-GostR3411-94-with-GostR3410-94'),
+ '1.2.643.2.2.9': ('GOST R 34.11-94', 'md_gost94'),
+ '1.2.643.2.2.10': ('HMAC GOST 34.11-94', 'id-HMACGostR3411-94'),
+ '1.2.643.2.2.14.0': ('id-Gost28147-89-None-KeyMeshing', ),
+ '1.2.643.2.2.14.1': ('id-Gost28147-89-CryptoPro-KeyMeshing', ),
+ '1.2.643.2.2.19': ('GOST R 34.10-2001', 'gost2001'),
+ '1.2.643.2.2.20': ('GOST R 34.10-94', 'gost94'),
+ '1.2.643.2.2.20.1': ('id-GostR3410-94-a', ),
+ '1.2.643.2.2.20.2': ('id-GostR3410-94-aBis', ),
+ '1.2.643.2.2.20.3': ('id-GostR3410-94-b', ),
+ '1.2.643.2.2.20.4': ('id-GostR3410-94-bBis', ),
+ '1.2.643.2.2.21': ('GOST 28147-89', 'gost89'),
+ '1.2.643.2.2.22': ('GOST 28147-89 MAC', 'gost-mac'),
+ '1.2.643.2.2.23': ('GOST R 34.11-94 PRF', 'prf-gostr3411-94'),
+ '1.2.643.2.2.30.0': ('id-GostR3411-94-TestParamSet', ),
+ '1.2.643.2.2.30.1': ('id-GostR3411-94-CryptoProParamSet', ),
+ '1.2.643.2.2.31.0': ('id-Gost28147-89-TestParamSet', ),
+ '1.2.643.2.2.31.1': ('id-Gost28147-89-CryptoPro-A-ParamSet', ),
+ '1.2.643.2.2.31.2': ('id-Gost28147-89-CryptoPro-B-ParamSet', ),
+ '1.2.643.2.2.31.3': ('id-Gost28147-89-CryptoPro-C-ParamSet', ),
+ '1.2.643.2.2.31.4': ('id-Gost28147-89-CryptoPro-D-ParamSet', ),
+ '1.2.643.2.2.31.5': ('id-Gost28147-89-CryptoPro-Oscar-1-1-ParamSet', ),
+ '1.2.643.2.2.31.6': ('id-Gost28147-89-CryptoPro-Oscar-1-0-ParamSet', ),
+ '1.2.643.2.2.31.7': ('id-Gost28147-89-CryptoPro-RIC-1-ParamSet', ),
+ '1.2.643.2.2.32.0': ('id-GostR3410-94-TestParamSet', ),
+ '1.2.643.2.2.32.2': ('id-GostR3410-94-CryptoPro-A-ParamSet', ),
+ '1.2.643.2.2.32.3': ('id-GostR3410-94-CryptoPro-B-ParamSet', ),
+ '1.2.643.2.2.32.4': ('id-GostR3410-94-CryptoPro-C-ParamSet', ),
+ '1.2.643.2.2.32.5': ('id-GostR3410-94-CryptoPro-D-ParamSet', ),
+ '1.2.643.2.2.33.1': ('id-GostR3410-94-CryptoPro-XchA-ParamSet', ),
+ '1.2.643.2.2.33.2': ('id-GostR3410-94-CryptoPro-XchB-ParamSet', ),
+ '1.2.643.2.2.33.3': ('id-GostR3410-94-CryptoPro-XchC-ParamSet', ),
+ '1.2.643.2.2.35.0': ('id-GostR3410-2001-TestParamSet', ),
+ '1.2.643.2.2.35.1': ('id-GostR3410-2001-CryptoPro-A-ParamSet', ),
+ '1.2.643.2.2.35.2': ('id-GostR3410-2001-CryptoPro-B-ParamSet', ),
+ '1.2.643.2.2.35.3': ('id-GostR3410-2001-CryptoPro-C-ParamSet', ),
+ '1.2.643.2.2.36.0': ('id-GostR3410-2001-CryptoPro-XchA-ParamSet', ),
+ '1.2.643.2.2.36.1': ('id-GostR3410-2001-CryptoPro-XchB-ParamSet', ),
+ '1.2.643.2.2.98': ('GOST R 34.10-2001 DH', 'id-GostR3410-2001DH'),
+ '1.2.643.2.2.99': ('GOST R 34.10-94 DH', 'id-GostR3410-94DH'),
+ '1.2.643.2.9': ('cryptocom', ),
+ '1.2.643.2.9.1.3.3': ('GOST R 34.11-94 with GOST R 34.10-94 Cryptocom', 'id-GostR3411-94-with-GostR3410-94-cc'),
+ '1.2.643.2.9.1.3.4': ('GOST R 34.11-94 with GOST R 34.10-2001 Cryptocom', 'id-GostR3411-94-with-GostR3410-2001-cc'),
+ '1.2.643.2.9.1.5.3': ('GOST 34.10-94 Cryptocom', 'gost94cc'),
+ '1.2.643.2.9.1.5.4': ('GOST 34.10-2001 Cryptocom', 'gost2001cc'),
+ '1.2.643.2.9.1.6.1': ('GOST 28147-89 Cryptocom ParamSet', 'id-Gost28147-89-cc'),
+ '1.2.643.2.9.1.8.1': ('GOST R 3410-2001 Parameter Set Cryptocom', 'id-GostR3410-2001-ParamSet-cc'),
+ '1.2.643.3.131.1.1': ('INN', 'INN'),
+ '1.2.643.7.1': ('id-tc26', ),
+ '1.2.643.7.1.1': ('id-tc26-algorithms', ),
+ '1.2.643.7.1.1.1': ('id-tc26-sign', ),
+ '1.2.643.7.1.1.1.1': ('GOST R 34.10-2012 with 256 bit modulus', 'gost2012_256'),
+ '1.2.643.7.1.1.1.2': ('GOST R 34.10-2012 with 512 bit modulus', 'gost2012_512'),
+ '1.2.643.7.1.1.2': ('id-tc26-digest', ),
+ '1.2.643.7.1.1.2.2': ('GOST R 34.11-2012 with 256 bit hash', 'md_gost12_256'),
+ '1.2.643.7.1.1.2.3': ('GOST R 34.11-2012 with 512 bit hash', 'md_gost12_512'),
+ '1.2.643.7.1.1.3': ('id-tc26-signwithdigest', ),
+ '1.2.643.7.1.1.3.2': ('GOST R 34.10-2012 with GOST R 34.11-2012 (256 bit)', 'id-tc26-signwithdigest-gost3410-2012-256'),
+ '1.2.643.7.1.1.3.3': ('GOST R 34.10-2012 with GOST R 34.11-2012 (512 bit)', 'id-tc26-signwithdigest-gost3410-2012-512'),
+ '1.2.643.7.1.1.4': ('id-tc26-mac', ),
+ '1.2.643.7.1.1.4.1': ('HMAC GOST 34.11-2012 256 bit', 'id-tc26-hmac-gost-3411-2012-256'),
+ '1.2.643.7.1.1.4.2': ('HMAC GOST 34.11-2012 512 bit', 'id-tc26-hmac-gost-3411-2012-512'),
+ '1.2.643.7.1.1.5': ('id-tc26-cipher', ),
+ '1.2.643.7.1.1.5.1': ('id-tc26-cipher-gostr3412-2015-magma', ),
+ '1.2.643.7.1.1.5.1.1': ('id-tc26-cipher-gostr3412-2015-magma-ctracpkm', ),
+ '1.2.643.7.1.1.5.1.2': ('id-tc26-cipher-gostr3412-2015-magma-ctracpkm-omac', ),
+ '1.2.643.7.1.1.5.2': ('id-tc26-cipher-gostr3412-2015-kuznyechik', ),
+ '1.2.643.7.1.1.5.2.1': ('id-tc26-cipher-gostr3412-2015-kuznyechik-ctracpkm', ),
+ '1.2.643.7.1.1.5.2.2': ('id-tc26-cipher-gostr3412-2015-kuznyechik-ctracpkm-omac', ),
+ '1.2.643.7.1.1.6': ('id-tc26-agreement', ),
+ '1.2.643.7.1.1.6.1': ('id-tc26-agreement-gost-3410-2012-256', ),
+ '1.2.643.7.1.1.6.2': ('id-tc26-agreement-gost-3410-2012-512', ),
+ '1.2.643.7.1.1.7': ('id-tc26-wrap', ),
+ '1.2.643.7.1.1.7.1': ('id-tc26-wrap-gostr3412-2015-magma', ),
+ '1.2.643.7.1.1.7.1.1': ('id-tc26-wrap-gostr3412-2015-magma-kexp15', 'id-tc26-wrap-gostr3412-2015-kuznyechik-kexp15'),
+ '1.2.643.7.1.1.7.2': ('id-tc26-wrap-gostr3412-2015-kuznyechik', ),
+ '1.2.643.7.1.2': ('id-tc26-constants', ),
+ '1.2.643.7.1.2.1': ('id-tc26-sign-constants', ),
+ '1.2.643.7.1.2.1.1': ('id-tc26-gost-3410-2012-256-constants', ),
+ '1.2.643.7.1.2.1.1.1': ('GOST R 34.10-2012 (256 bit) ParamSet A', 'id-tc26-gost-3410-2012-256-paramSetA'),
+ '1.2.643.7.1.2.1.1.2': ('GOST R 34.10-2012 (256 bit) ParamSet B', 'id-tc26-gost-3410-2012-256-paramSetB'),
+ '1.2.643.7.1.2.1.1.3': ('GOST R 34.10-2012 (256 bit) ParamSet C', 'id-tc26-gost-3410-2012-256-paramSetC'),
+ '1.2.643.7.1.2.1.1.4': ('GOST R 34.10-2012 (256 bit) ParamSet D', 'id-tc26-gost-3410-2012-256-paramSetD'),
+ '1.2.643.7.1.2.1.2': ('id-tc26-gost-3410-2012-512-constants', ),
+ '1.2.643.7.1.2.1.2.0': ('GOST R 34.10-2012 (512 bit) testing parameter set', 'id-tc26-gost-3410-2012-512-paramSetTest'),
+ '1.2.643.7.1.2.1.2.1': ('GOST R 34.10-2012 (512 bit) ParamSet A', 'id-tc26-gost-3410-2012-512-paramSetA'),
+ '1.2.643.7.1.2.1.2.2': ('GOST R 34.10-2012 (512 bit) ParamSet B', 'id-tc26-gost-3410-2012-512-paramSetB'),
+ '1.2.643.7.1.2.1.2.3': ('GOST R 34.10-2012 (512 bit) ParamSet C', 'id-tc26-gost-3410-2012-512-paramSetC'),
+ '1.2.643.7.1.2.2': ('id-tc26-digest-constants', ),
+ '1.2.643.7.1.2.5': ('id-tc26-cipher-constants', ),
+ '1.2.643.7.1.2.5.1': ('id-tc26-gost-28147-constants', ),
+ '1.2.643.7.1.2.5.1.1': ('GOST 28147-89 TC26 parameter set', 'id-tc26-gost-28147-param-Z'),
+ '1.2.643.100.1': ('OGRN', 'OGRN'),
+ '1.2.643.100.3': ('SNILS', 'SNILS'),
+ '1.2.643.100.111': ('Signing Tool of Subject', 'subjectSignTool'),
+ '1.2.643.100.112': ('Signing Tool of Issuer', 'issuerSignTool'),
+ '1.2.804': ('ISO-UA', ),
+ '1.2.804.2.1.1.1': ('ua-pki', ),
+ '1.2.804.2.1.1.1.1.1.1': ('DSTU Gost 28147-2009', 'dstu28147'),
+ '1.2.804.2.1.1.1.1.1.1.2': ('DSTU Gost 28147-2009 OFB mode', 'dstu28147-ofb'),
+ '1.2.804.2.1.1.1.1.1.1.3': ('DSTU Gost 28147-2009 CFB mode', 'dstu28147-cfb'),
+ '1.2.804.2.1.1.1.1.1.1.5': ('DSTU Gost 28147-2009 key wrap', 'dstu28147-wrap'),
+ '1.2.804.2.1.1.1.1.1.2': ('HMAC DSTU Gost 34311-95', 'hmacWithDstu34311'),
+ '1.2.804.2.1.1.1.1.2.1': ('DSTU Gost 34311-95', 'dstu34311'),
+ '1.2.804.2.1.1.1.1.3.1.1': ('DSTU 4145-2002 little endian', 'dstu4145le'),
+ '1.2.804.2.1.1.1.1.3.1.1.1.1': ('DSTU 4145-2002 big endian', 'dstu4145be'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.0': ('DSTU curve 0', 'uacurve0'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.1': ('DSTU curve 1', 'uacurve1'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.2': ('DSTU curve 2', 'uacurve2'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.3': ('DSTU curve 3', 'uacurve3'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.4': ('DSTU curve 4', 'uacurve4'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.5': ('DSTU curve 5', 'uacurve5'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.6': ('DSTU curve 6', 'uacurve6'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.7': ('DSTU curve 7', 'uacurve7'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.8': ('DSTU curve 8', 'uacurve8'),
+ '1.2.804.2.1.1.1.1.3.1.1.2.9': ('DSTU curve 9', 'uacurve9'),
+ '1.2.840': ('ISO US Member Body', 'ISO-US'),
+ '1.2.840.10040': ('X9.57', 'X9-57'),
+ '1.2.840.10040.2': ('holdInstruction', ),
+ '1.2.840.10040.2.1': ('Hold Instruction None', 'holdInstructionNone'),
+ '1.2.840.10040.2.2': ('Hold Instruction Call Issuer', 'holdInstructionCallIssuer'),
+ '1.2.840.10040.2.3': ('Hold Instruction Reject', 'holdInstructionReject'),
+ '1.2.840.10040.4': ('X9.57 CM ?', 'X9cm'),
+ '1.2.840.10040.4.1': ('dsaEncryption', 'DSA'),
+ '1.2.840.10040.4.3': ('dsaWithSHA1', 'DSA-SHA1'),
+ '1.2.840.10045': ('ANSI X9.62', 'ansi-X9-62'),
+ '1.2.840.10045.1': ('id-fieldType', ),
+ '1.2.840.10045.1.1': ('prime-field', ),
+ '1.2.840.10045.1.2': ('characteristic-two-field', ),
+ '1.2.840.10045.1.2.3': ('id-characteristic-two-basis', ),
+ '1.2.840.10045.1.2.3.1': ('onBasis', ),
+ '1.2.840.10045.1.2.3.2': ('tpBasis', ),
+ '1.2.840.10045.1.2.3.3': ('ppBasis', ),
+ '1.2.840.10045.2': ('id-publicKeyType', ),
+ '1.2.840.10045.2.1': ('id-ecPublicKey', ),
+ '1.2.840.10045.3': ('ellipticCurve', ),
+ '1.2.840.10045.3.0': ('c-TwoCurve', ),
+ '1.2.840.10045.3.0.1': ('c2pnb163v1', ),
+ '1.2.840.10045.3.0.2': ('c2pnb163v2', ),
+ '1.2.840.10045.3.0.3': ('c2pnb163v3', ),
+ '1.2.840.10045.3.0.4': ('c2pnb176v1', ),
+ '1.2.840.10045.3.0.5': ('c2tnb191v1', ),
+ '1.2.840.10045.3.0.6': ('c2tnb191v2', ),
+ '1.2.840.10045.3.0.7': ('c2tnb191v3', ),
+ '1.2.840.10045.3.0.8': ('c2onb191v4', ),
+ '1.2.840.10045.3.0.9': ('c2onb191v5', ),
+ '1.2.840.10045.3.0.10': ('c2pnb208w1', ),
+ '1.2.840.10045.3.0.11': ('c2tnb239v1', ),
+ '1.2.840.10045.3.0.12': ('c2tnb239v2', ),
+ '1.2.840.10045.3.0.13': ('c2tnb239v3', ),
+ '1.2.840.10045.3.0.14': ('c2onb239v4', ),
+ '1.2.840.10045.3.0.15': ('c2onb239v5', ),
+ '1.2.840.10045.3.0.16': ('c2pnb272w1', ),
+ '1.2.840.10045.3.0.17': ('c2pnb304w1', ),
+ '1.2.840.10045.3.0.18': ('c2tnb359v1', ),
+ '1.2.840.10045.3.0.19': ('c2pnb368w1', ),
+ '1.2.840.10045.3.0.20': ('c2tnb431r1', ),
+ '1.2.840.10045.3.1': ('primeCurve', ),
+ '1.2.840.10045.3.1.1': ('prime192v1', ),
+ '1.2.840.10045.3.1.2': ('prime192v2', ),
+ '1.2.840.10045.3.1.3': ('prime192v3', ),
+ '1.2.840.10045.3.1.4': ('prime239v1', ),
+ '1.2.840.10045.3.1.5': ('prime239v2', ),
+ '1.2.840.10045.3.1.6': ('prime239v3', ),
+ '1.2.840.10045.3.1.7': ('prime256v1', ),
+ '1.2.840.10045.4': ('id-ecSigType', ),
+ '1.2.840.10045.4.1': ('ecdsa-with-SHA1', ),
+ '1.2.840.10045.4.2': ('ecdsa-with-Recommended', ),
+ '1.2.840.10045.4.3': ('ecdsa-with-Specified', ),
+ '1.2.840.10045.4.3.1': ('ecdsa-with-SHA224', ),
+ '1.2.840.10045.4.3.2': ('ecdsa-with-SHA256', ),
+ '1.2.840.10045.4.3.3': ('ecdsa-with-SHA384', ),
+ '1.2.840.10045.4.3.4': ('ecdsa-with-SHA512', ),
+ '1.2.840.10046.2.1': ('X9.42 DH', 'dhpublicnumber'),
+ '1.2.840.113533.7.66.10': ('cast5-cbc', 'CAST5-CBC'),
+ '1.2.840.113533.7.66.12': ('pbeWithMD5AndCast5CBC', ),
+ '1.2.840.113533.7.66.13': ('password based MAC', 'id-PasswordBasedMAC'),
+ '1.2.840.113533.7.66.30': ('Diffie-Hellman based MAC', 'id-DHBasedMac'),
+ '1.2.840.113549': ('RSA Data Security, Inc.', 'rsadsi'),
+ '1.2.840.113549.1': ('RSA Data Security, Inc. PKCS', 'pkcs'),
+ '1.2.840.113549.1.1': ('pkcs1', ),
+ '1.2.840.113549.1.1.1': ('rsaEncryption', ),
+ '1.2.840.113549.1.1.2': ('md2WithRSAEncryption', 'RSA-MD2'),
+ '1.2.840.113549.1.1.3': ('md4WithRSAEncryption', 'RSA-MD4'),
+ '1.2.840.113549.1.1.4': ('md5WithRSAEncryption', 'RSA-MD5'),
+ '1.2.840.113549.1.1.5': ('sha1WithRSAEncryption', 'RSA-SHA1'),
+ '1.2.840.113549.1.1.6': ('rsaOAEPEncryptionSET', ),
+ '1.2.840.113549.1.1.7': ('rsaesOaep', 'RSAES-OAEP'),
+ '1.2.840.113549.1.1.8': ('mgf1', 'MGF1'),
+ '1.2.840.113549.1.1.9': ('pSpecified', 'PSPECIFIED'),
+ '1.2.840.113549.1.1.10': ('rsassaPss', 'RSASSA-PSS'),
+ '1.2.840.113549.1.1.11': ('sha256WithRSAEncryption', 'RSA-SHA256'),
+ '1.2.840.113549.1.1.12': ('sha384WithRSAEncryption', 'RSA-SHA384'),
+ '1.2.840.113549.1.1.13': ('sha512WithRSAEncryption', 'RSA-SHA512'),
+ '1.2.840.113549.1.1.14': ('sha224WithRSAEncryption', 'RSA-SHA224'),
+ '1.2.840.113549.1.1.15': ('sha512-224WithRSAEncryption', 'RSA-SHA512/224'),
+ '1.2.840.113549.1.1.16': ('sha512-256WithRSAEncryption', 'RSA-SHA512/256'),
+ '1.2.840.113549.1.3': ('pkcs3', ),
+ '1.2.840.113549.1.3.1': ('dhKeyAgreement', ),
+ '1.2.840.113549.1.5': ('pkcs5', ),
+ '1.2.840.113549.1.5.1': ('pbeWithMD2AndDES-CBC', 'PBE-MD2-DES'),
+ '1.2.840.113549.1.5.3': ('pbeWithMD5AndDES-CBC', 'PBE-MD5-DES'),
+ '1.2.840.113549.1.5.4': ('pbeWithMD2AndRC2-CBC', 'PBE-MD2-RC2-64'),
+ '1.2.840.113549.1.5.6': ('pbeWithMD5AndRC2-CBC', 'PBE-MD5-RC2-64'),
+ '1.2.840.113549.1.5.10': ('pbeWithSHA1AndDES-CBC', 'PBE-SHA1-DES'),
+ '1.2.840.113549.1.5.11': ('pbeWithSHA1AndRC2-CBC', 'PBE-SHA1-RC2-64'),
+ '1.2.840.113549.1.5.12': ('PBKDF2', ),
+ '1.2.840.113549.1.5.13': ('PBES2', ),
+ '1.2.840.113549.1.5.14': ('PBMAC1', ),
+ '1.2.840.113549.1.7': ('pkcs7', ),
+ '1.2.840.113549.1.7.1': ('pkcs7-data', ),
+ '1.2.840.113549.1.7.2': ('pkcs7-signedData', ),
+ '1.2.840.113549.1.7.3': ('pkcs7-envelopedData', ),
+ '1.2.840.113549.1.7.4': ('pkcs7-signedAndEnvelopedData', ),
+ '1.2.840.113549.1.7.5': ('pkcs7-digestData', ),
+ '1.2.840.113549.1.7.6': ('pkcs7-encryptedData', ),
+ '1.2.840.113549.1.9': ('pkcs9', ),
+ '1.2.840.113549.1.9.1': ('emailAddress', ),
+ '1.2.840.113549.1.9.2': ('unstructuredName', ),
+ '1.2.840.113549.1.9.3': ('contentType', ),
+ '1.2.840.113549.1.9.4': ('messageDigest', ),
+ '1.2.840.113549.1.9.5': ('signingTime', ),
+ '1.2.840.113549.1.9.6': ('countersignature', ),
+ '1.2.840.113549.1.9.7': ('challengePassword', ),
+ '1.2.840.113549.1.9.8': ('unstructuredAddress', ),
+ '1.2.840.113549.1.9.9': ('extendedCertificateAttributes', ),
+ '1.2.840.113549.1.9.14': ('Extension Request', 'extReq'),
+ '1.2.840.113549.1.9.15': ('S/MIME Capabilities', 'SMIME-CAPS'),
+ '1.2.840.113549.1.9.16': ('S/MIME', 'SMIME'),
+ '1.2.840.113549.1.9.16.0': ('id-smime-mod', ),
+ '1.2.840.113549.1.9.16.0.1': ('id-smime-mod-cms', ),
+ '1.2.840.113549.1.9.16.0.2': ('id-smime-mod-ess', ),
+ '1.2.840.113549.1.9.16.0.3': ('id-smime-mod-oid', ),
+ '1.2.840.113549.1.9.16.0.4': ('id-smime-mod-msg-v3', ),
+ '1.2.840.113549.1.9.16.0.5': ('id-smime-mod-ets-eSignature-88', ),
+ '1.2.840.113549.1.9.16.0.6': ('id-smime-mod-ets-eSignature-97', ),
+ '1.2.840.113549.1.9.16.0.7': ('id-smime-mod-ets-eSigPolicy-88', ),
+ '1.2.840.113549.1.9.16.0.8': ('id-smime-mod-ets-eSigPolicy-97', ),
+ '1.2.840.113549.1.9.16.1': ('id-smime-ct', ),
+ '1.2.840.113549.1.9.16.1.1': ('id-smime-ct-receipt', ),
+ '1.2.840.113549.1.9.16.1.2': ('id-smime-ct-authData', ),
+ '1.2.840.113549.1.9.16.1.3': ('id-smime-ct-publishCert', ),
+ '1.2.840.113549.1.9.16.1.4': ('id-smime-ct-TSTInfo', ),
+ '1.2.840.113549.1.9.16.1.5': ('id-smime-ct-TDTInfo', ),
+ '1.2.840.113549.1.9.16.1.6': ('id-smime-ct-contentInfo', ),
+ '1.2.840.113549.1.9.16.1.7': ('id-smime-ct-DVCSRequestData', ),
+ '1.2.840.113549.1.9.16.1.8': ('id-smime-ct-DVCSResponseData', ),
+ '1.2.840.113549.1.9.16.1.9': ('id-smime-ct-compressedData', ),
+ '1.2.840.113549.1.9.16.1.19': ('id-smime-ct-contentCollection', ),
+ '1.2.840.113549.1.9.16.1.23': ('id-smime-ct-authEnvelopedData', ),
+ '1.2.840.113549.1.9.16.1.27': ('id-ct-asciiTextWithCRLF', ),
+ '1.2.840.113549.1.9.16.1.28': ('id-ct-xml', ),
+ '1.2.840.113549.1.9.16.2': ('id-smime-aa', ),
+ '1.2.840.113549.1.9.16.2.1': ('id-smime-aa-receiptRequest', ),
+ '1.2.840.113549.1.9.16.2.2': ('id-smime-aa-securityLabel', ),
+ '1.2.840.113549.1.9.16.2.3': ('id-smime-aa-mlExpandHistory', ),
+ '1.2.840.113549.1.9.16.2.4': ('id-smime-aa-contentHint', ),
+ '1.2.840.113549.1.9.16.2.5': ('id-smime-aa-msgSigDigest', ),
+ '1.2.840.113549.1.9.16.2.6': ('id-smime-aa-encapContentType', ),
+ '1.2.840.113549.1.9.16.2.7': ('id-smime-aa-contentIdentifier', ),
+ '1.2.840.113549.1.9.16.2.8': ('id-smime-aa-macValue', ),
+ '1.2.840.113549.1.9.16.2.9': ('id-smime-aa-equivalentLabels', ),
+ '1.2.840.113549.1.9.16.2.10': ('id-smime-aa-contentReference', ),
+ '1.2.840.113549.1.9.16.2.11': ('id-smime-aa-encrypKeyPref', ),
+ '1.2.840.113549.1.9.16.2.12': ('id-smime-aa-signingCertificate', ),
+ '1.2.840.113549.1.9.16.2.13': ('id-smime-aa-smimeEncryptCerts', ),
+ '1.2.840.113549.1.9.16.2.14': ('id-smime-aa-timeStampToken', ),
+ '1.2.840.113549.1.9.16.2.15': ('id-smime-aa-ets-sigPolicyId', ),
+ '1.2.840.113549.1.9.16.2.16': ('id-smime-aa-ets-commitmentType', ),
+ '1.2.840.113549.1.9.16.2.17': ('id-smime-aa-ets-signerLocation', ),
+ '1.2.840.113549.1.9.16.2.18': ('id-smime-aa-ets-signerAttr', ),
+ '1.2.840.113549.1.9.16.2.19': ('id-smime-aa-ets-otherSigCert', ),
+ '1.2.840.113549.1.9.16.2.20': ('id-smime-aa-ets-contentTimestamp', ),
+ '1.2.840.113549.1.9.16.2.21': ('id-smime-aa-ets-CertificateRefs', ),
+ '1.2.840.113549.1.9.16.2.22': ('id-smime-aa-ets-RevocationRefs', ),
+ '1.2.840.113549.1.9.16.2.23': ('id-smime-aa-ets-certValues', ),
+ '1.2.840.113549.1.9.16.2.24': ('id-smime-aa-ets-revocationValues', ),
+ '1.2.840.113549.1.9.16.2.25': ('id-smime-aa-ets-escTimeStamp', ),
+ '1.2.840.113549.1.9.16.2.26': ('id-smime-aa-ets-certCRLTimestamp', ),
+ '1.2.840.113549.1.9.16.2.27': ('id-smime-aa-ets-archiveTimeStamp', ),
+ '1.2.840.113549.1.9.16.2.28': ('id-smime-aa-signatureType', ),
+ '1.2.840.113549.1.9.16.2.29': ('id-smime-aa-dvcs-dvc', ),
+ '1.2.840.113549.1.9.16.2.47': ('id-smime-aa-signingCertificateV2', ),
+ '1.2.840.113549.1.9.16.3': ('id-smime-alg', ),
+ '1.2.840.113549.1.9.16.3.1': ('id-smime-alg-ESDHwith3DES', ),
+ '1.2.840.113549.1.9.16.3.2': ('id-smime-alg-ESDHwithRC2', ),
+ '1.2.840.113549.1.9.16.3.3': ('id-smime-alg-3DESwrap', ),
+ '1.2.840.113549.1.9.16.3.4': ('id-smime-alg-RC2wrap', ),
+ '1.2.840.113549.1.9.16.3.5': ('id-smime-alg-ESDH', ),
+ '1.2.840.113549.1.9.16.3.6': ('id-smime-alg-CMS3DESwrap', ),
+ '1.2.840.113549.1.9.16.3.7': ('id-smime-alg-CMSRC2wrap', ),
+ '1.2.840.113549.1.9.16.3.8': ('zlib compression', 'ZLIB'),
+ '1.2.840.113549.1.9.16.3.9': ('id-alg-PWRI-KEK', ),
+ '1.2.840.113549.1.9.16.4': ('id-smime-cd', ),
+ '1.2.840.113549.1.9.16.4.1': ('id-smime-cd-ldap', ),
+ '1.2.840.113549.1.9.16.5': ('id-smime-spq', ),
+ '1.2.840.113549.1.9.16.5.1': ('id-smime-spq-ets-sqt-uri', ),
+ '1.2.840.113549.1.9.16.5.2': ('id-smime-spq-ets-sqt-unotice', ),
+ '1.2.840.113549.1.9.16.6': ('id-smime-cti', ),
+ '1.2.840.113549.1.9.16.6.1': ('id-smime-cti-ets-proofOfOrigin', ),
+ '1.2.840.113549.1.9.16.6.2': ('id-smime-cti-ets-proofOfReceipt', ),
+ '1.2.840.113549.1.9.16.6.3': ('id-smime-cti-ets-proofOfDelivery', ),
+ '1.2.840.113549.1.9.16.6.4': ('id-smime-cti-ets-proofOfSender', ),
+ '1.2.840.113549.1.9.16.6.5': ('id-smime-cti-ets-proofOfApproval', ),
+ '1.2.840.113549.1.9.16.6.6': ('id-smime-cti-ets-proofOfCreation', ),
+ '1.2.840.113549.1.9.20': ('friendlyName', ),
+ '1.2.840.113549.1.9.21': ('localKeyID', ),
+ '1.2.840.113549.1.9.22': ('certTypes', ),
+ '1.2.840.113549.1.9.22.1': ('x509Certificate', ),
+ '1.2.840.113549.1.9.22.2': ('sdsiCertificate', ),
+ '1.2.840.113549.1.9.23': ('crlTypes', ),
+ '1.2.840.113549.1.9.23.1': ('x509Crl', ),
+ '1.2.840.113549.1.12': ('pkcs12', ),
+ '1.2.840.113549.1.12.1': ('pkcs12-pbeids', ),
+ '1.2.840.113549.1.12.1.1': ('pbeWithSHA1And128BitRC4', 'PBE-SHA1-RC4-128'),
+ '1.2.840.113549.1.12.1.2': ('pbeWithSHA1And40BitRC4', 'PBE-SHA1-RC4-40'),
+ '1.2.840.113549.1.12.1.3': ('pbeWithSHA1And3-KeyTripleDES-CBC', 'PBE-SHA1-3DES'),
+ '1.2.840.113549.1.12.1.4': ('pbeWithSHA1And2-KeyTripleDES-CBC', 'PBE-SHA1-2DES'),
+ '1.2.840.113549.1.12.1.5': ('pbeWithSHA1And128BitRC2-CBC', 'PBE-SHA1-RC2-128'),
+ '1.2.840.113549.1.12.1.6': ('pbeWithSHA1And40BitRC2-CBC', 'PBE-SHA1-RC2-40'),
+ '1.2.840.113549.1.12.10': ('pkcs12-Version1', ),
+ '1.2.840.113549.1.12.10.1': ('pkcs12-BagIds', ),
+ '1.2.840.113549.1.12.10.1.1': ('keyBag', ),
+ '1.2.840.113549.1.12.10.1.2': ('pkcs8ShroudedKeyBag', ),
+ '1.2.840.113549.1.12.10.1.3': ('certBag', ),
+ '1.2.840.113549.1.12.10.1.4': ('crlBag', ),
+ '1.2.840.113549.1.12.10.1.5': ('secretBag', ),
+ '1.2.840.113549.1.12.10.1.6': ('safeContentsBag', ),
+ '1.2.840.113549.2.2': ('md2', 'MD2'),
+ '1.2.840.113549.2.4': ('md4', 'MD4'),
+ '1.2.840.113549.2.5': ('md5', 'MD5'),
+ '1.2.840.113549.2.6': ('hmacWithMD5', ),
+ '1.2.840.113549.2.7': ('hmacWithSHA1', ),
+ '1.2.840.113549.2.8': ('hmacWithSHA224', ),
+ '1.2.840.113549.2.9': ('hmacWithSHA256', ),
+ '1.2.840.113549.2.10': ('hmacWithSHA384', ),
+ '1.2.840.113549.2.11': ('hmacWithSHA512', ),
+ '1.2.840.113549.2.12': ('hmacWithSHA512-224', ),
+ '1.2.840.113549.2.13': ('hmacWithSHA512-256', ),
+ '1.2.840.113549.3.2': ('rc2-cbc', 'RC2-CBC'),
+ '1.2.840.113549.3.4': ('rc4', 'RC4'),
+ '1.2.840.113549.3.7': ('des-ede3-cbc', 'DES-EDE3-CBC'),
+ '1.2.840.113549.3.8': ('rc5-cbc', 'RC5-CBC'),
+ '1.2.840.113549.3.10': ('des-cdmf', 'DES-CDMF'),
+ '1.3': ('identified-organization', 'org', 'ORG'),
+ '1.3.6': ('dod', 'DOD'),
+ '1.3.6.1': ('iana', 'IANA', 'internet'),
+ '1.3.6.1.1': ('Directory', 'directory'),
+ '1.3.6.1.2': ('Management', 'mgmt'),
+ '1.3.6.1.3': ('Experimental', 'experimental'),
+ '1.3.6.1.4': ('Private', 'private'),
+ '1.3.6.1.4.1': ('Enterprises', 'enterprises'),
+ '1.3.6.1.4.1.188.7.1.1.2': ('idea-cbc', 'IDEA-CBC'),
+ '1.3.6.1.4.1.311.2.1.14': ('Microsoft Extension Request', 'msExtReq'),
+ '1.3.6.1.4.1.311.2.1.21': ('Microsoft Individual Code Signing', 'msCodeInd'),
+ '1.3.6.1.4.1.311.2.1.22': ('Microsoft Commercial Code Signing', 'msCodeCom'),
+ '1.3.6.1.4.1.311.10.3.1': ('Microsoft Trust List Signing', 'msCTLSign'),
+ '1.3.6.1.4.1.311.10.3.3': ('Microsoft Server Gated Crypto', 'msSGC'),
+ '1.3.6.1.4.1.311.10.3.4': ('Microsoft Encrypted File System', 'msEFS'),
+ '1.3.6.1.4.1.311.17.1': ('Microsoft CSP Name', 'CSPName'),
+ '1.3.6.1.4.1.311.17.2': ('Microsoft Local Key set', 'LocalKeySet'),
+ '1.3.6.1.4.1.311.20.2.2': ('Microsoft Smartcardlogin', 'msSmartcardLogin'),
+ '1.3.6.1.4.1.311.20.2.3': ('Microsoft Universal Principal Name', 'msUPN'),
+ '1.3.6.1.4.1.311.60.2.1.1': ('jurisdictionLocalityName', 'jurisdictionL'),
+ '1.3.6.1.4.1.311.60.2.1.2': ('jurisdictionStateOrProvinceName', 'jurisdictionST'),
+ '1.3.6.1.4.1.311.60.2.1.3': ('jurisdictionCountryName', 'jurisdictionC'),
+ '1.3.6.1.4.1.1466.344': ('dcObject', 'dcobject'),
+ '1.3.6.1.4.1.1722.12.2.1.16': ('blake2b512', 'BLAKE2b512'),
+ '1.3.6.1.4.1.1722.12.2.2.8': ('blake2s256', 'BLAKE2s256'),
+ '1.3.6.1.4.1.3029.1.2': ('bf-cbc', 'BF-CBC'),
+ '1.3.6.1.4.1.11129.2.4.2': ('CT Precertificate SCTs', 'ct_precert_scts'),
+ '1.3.6.1.4.1.11129.2.4.3': ('CT Precertificate Poison', 'ct_precert_poison'),
+ '1.3.6.1.4.1.11129.2.4.4': ('CT Precertificate Signer', 'ct_precert_signer'),
+ '1.3.6.1.4.1.11129.2.4.5': ('CT Certificate SCTs', 'ct_cert_scts'),
+ '1.3.6.1.4.1.11591.4.11': ('scrypt', 'id-scrypt'),
+ '1.3.6.1.5': ('Security', 'security'),
+ '1.3.6.1.5.2.3': ('id-pkinit', ),
+ '1.3.6.1.5.2.3.4': ('PKINIT Client Auth', 'pkInitClientAuth'),
+ '1.3.6.1.5.2.3.5': ('Signing KDC Response', 'pkInitKDC'),
+ '1.3.6.1.5.5.7': ('PKIX', ),
+ '1.3.6.1.5.5.7.0': ('id-pkix-mod', ),
+ '1.3.6.1.5.5.7.0.1': ('id-pkix1-explicit-88', ),
+ '1.3.6.1.5.5.7.0.2': ('id-pkix1-implicit-88', ),
+ '1.3.6.1.5.5.7.0.3': ('id-pkix1-explicit-93', ),
+ '1.3.6.1.5.5.7.0.4': ('id-pkix1-implicit-93', ),
+ '1.3.6.1.5.5.7.0.5': ('id-mod-crmf', ),
+ '1.3.6.1.5.5.7.0.6': ('id-mod-cmc', ),
+ '1.3.6.1.5.5.7.0.7': ('id-mod-kea-profile-88', ),
+ '1.3.6.1.5.5.7.0.8': ('id-mod-kea-profile-93', ),
+ '1.3.6.1.5.5.7.0.9': ('id-mod-cmp', ),
+ '1.3.6.1.5.5.7.0.10': ('id-mod-qualified-cert-88', ),
+ '1.3.6.1.5.5.7.0.11': ('id-mod-qualified-cert-93', ),
+ '1.3.6.1.5.5.7.0.12': ('id-mod-attribute-cert', ),
+ '1.3.6.1.5.5.7.0.13': ('id-mod-timestamp-protocol', ),
+ '1.3.6.1.5.5.7.0.14': ('id-mod-ocsp', ),
+ '1.3.6.1.5.5.7.0.15': ('id-mod-dvcs', ),
+ '1.3.6.1.5.5.7.0.16': ('id-mod-cmp2000', ),
+ '1.3.6.1.5.5.7.1': ('id-pe', ),
+ '1.3.6.1.5.5.7.1.1': ('Authority Information Access', 'authorityInfoAccess'),
+ '1.3.6.1.5.5.7.1.2': ('Biometric Info', 'biometricInfo'),
+ '1.3.6.1.5.5.7.1.3': ('qcStatements', ),
+ '1.3.6.1.5.5.7.1.4': ('ac-auditEntity', ),
+ '1.3.6.1.5.5.7.1.5': ('ac-targeting', ),
+ '1.3.6.1.5.5.7.1.6': ('aaControls', ),
+ '1.3.6.1.5.5.7.1.7': ('sbgp-ipAddrBlock', ),
+ '1.3.6.1.5.5.7.1.8': ('sbgp-autonomousSysNum', ),
+ '1.3.6.1.5.5.7.1.9': ('sbgp-routerIdentifier', ),
+ '1.3.6.1.5.5.7.1.10': ('ac-proxying', ),
+ '1.3.6.1.5.5.7.1.11': ('Subject Information Access', 'subjectInfoAccess'),
+ '1.3.6.1.5.5.7.1.14': ('Proxy Certificate Information', 'proxyCertInfo'),
+ '1.3.6.1.5.5.7.1.24': ('TLS Feature', 'tlsfeature'),
+ '1.3.6.1.5.5.7.2': ('id-qt', ),
+ '1.3.6.1.5.5.7.2.1': ('Policy Qualifier CPS', 'id-qt-cps'),
+ '1.3.6.1.5.5.7.2.2': ('Policy Qualifier User Notice', 'id-qt-unotice'),
+ '1.3.6.1.5.5.7.2.3': ('textNotice', ),
+ '1.3.6.1.5.5.7.3': ('id-kp', ),
+ '1.3.6.1.5.5.7.3.1': ('TLS Web Server Authentication', 'serverAuth'),
+ '1.3.6.1.5.5.7.3.2': ('TLS Web Client Authentication', 'clientAuth'),
+ '1.3.6.1.5.5.7.3.3': ('Code Signing', 'codeSigning'),
+ '1.3.6.1.5.5.7.3.4': ('E-mail Protection', 'emailProtection'),
+ '1.3.6.1.5.5.7.3.5': ('IPSec End System', 'ipsecEndSystem'),
+ '1.3.6.1.5.5.7.3.6': ('IPSec Tunnel', 'ipsecTunnel'),
+ '1.3.6.1.5.5.7.3.7': ('IPSec User', 'ipsecUser'),
+ '1.3.6.1.5.5.7.3.8': ('Time Stamping', 'timeStamping'),
+ '1.3.6.1.5.5.7.3.9': ('OCSP Signing', 'OCSPSigning'),
+ '1.3.6.1.5.5.7.3.10': ('dvcs', 'DVCS'),
+ '1.3.6.1.5.5.7.3.17': ('ipsec Internet Key Exchange', 'ipsecIKE'),
+ '1.3.6.1.5.5.7.3.18': ('Ctrl/provision WAP Access', 'capwapAC'),
+ '1.3.6.1.5.5.7.3.19': ('Ctrl/Provision WAP Termination', 'capwapWTP'),
+ '1.3.6.1.5.5.7.3.21': ('SSH Client', 'secureShellClient'),
+ '1.3.6.1.5.5.7.3.22': ('SSH Server', 'secureShellServer'),
+ '1.3.6.1.5.5.7.3.23': ('Send Router', 'sendRouter'),
+ '1.3.6.1.5.5.7.3.24': ('Send Proxied Router', 'sendProxiedRouter'),
+ '1.3.6.1.5.5.7.3.25': ('Send Owner', 'sendOwner'),
+ '1.3.6.1.5.5.7.3.26': ('Send Proxied Owner', 'sendProxiedOwner'),
+ '1.3.6.1.5.5.7.3.27': ('CMC Certificate Authority', 'cmcCA'),
+ '1.3.6.1.5.5.7.3.28': ('CMC Registration Authority', 'cmcRA'),
+ '1.3.6.1.5.5.7.4': ('id-it', ),
+ '1.3.6.1.5.5.7.4.1': ('id-it-caProtEncCert', ),
+ '1.3.6.1.5.5.7.4.2': ('id-it-signKeyPairTypes', ),
+ '1.3.6.1.5.5.7.4.3': ('id-it-encKeyPairTypes', ),
+ '1.3.6.1.5.5.7.4.4': ('id-it-preferredSymmAlg', ),
+ '1.3.6.1.5.5.7.4.5': ('id-it-caKeyUpdateInfo', ),
+ '1.3.6.1.5.5.7.4.6': ('id-it-currentCRL', ),
+ '1.3.6.1.5.5.7.4.7': ('id-it-unsupportedOIDs', ),
+ '1.3.6.1.5.5.7.4.8': ('id-it-subscriptionRequest', ),
+ '1.3.6.1.5.5.7.4.9': ('id-it-subscriptionResponse', ),
+ '1.3.6.1.5.5.7.4.10': ('id-it-keyPairParamReq', ),
+ '1.3.6.1.5.5.7.4.11': ('id-it-keyPairParamRep', ),
+ '1.3.6.1.5.5.7.4.12': ('id-it-revPassphrase', ),
+ '1.3.6.1.5.5.7.4.13': ('id-it-implicitConfirm', ),
+ '1.3.6.1.5.5.7.4.14': ('id-it-confirmWaitTime', ),
+ '1.3.6.1.5.5.7.4.15': ('id-it-origPKIMessage', ),
+ '1.3.6.1.5.5.7.4.16': ('id-it-suppLangTags', ),
+ '1.3.6.1.5.5.7.5': ('id-pkip', ),
+ '1.3.6.1.5.5.7.5.1': ('id-regCtrl', ),
+ '1.3.6.1.5.5.7.5.1.1': ('id-regCtrl-regToken', ),
+ '1.3.6.1.5.5.7.5.1.2': ('id-regCtrl-authenticator', ),
+ '1.3.6.1.5.5.7.5.1.3': ('id-regCtrl-pkiPublicationInfo', ),
+ '1.3.6.1.5.5.7.5.1.4': ('id-regCtrl-pkiArchiveOptions', ),
+ '1.3.6.1.5.5.7.5.1.5': ('id-regCtrl-oldCertID', ),
+ '1.3.6.1.5.5.7.5.1.6': ('id-regCtrl-protocolEncrKey', ),
+ '1.3.6.1.5.5.7.5.2': ('id-regInfo', ),
+ '1.3.6.1.5.5.7.5.2.1': ('id-regInfo-utf8Pairs', ),
+ '1.3.6.1.5.5.7.5.2.2': ('id-regInfo-certReq', ),
+ '1.3.6.1.5.5.7.6': ('id-alg', ),
+ '1.3.6.1.5.5.7.6.1': ('id-alg-des40', ),
+ '1.3.6.1.5.5.7.6.2': ('id-alg-noSignature', ),
+ '1.3.6.1.5.5.7.6.3': ('id-alg-dh-sig-hmac-sha1', ),
+ '1.3.6.1.5.5.7.6.4': ('id-alg-dh-pop', ),
+ '1.3.6.1.5.5.7.7': ('id-cmc', ),
+ '1.3.6.1.5.5.7.7.1': ('id-cmc-statusInfo', ),
+ '1.3.6.1.5.5.7.7.2': ('id-cmc-identification', ),
+ '1.3.6.1.5.5.7.7.3': ('id-cmc-identityProof', ),
+ '1.3.6.1.5.5.7.7.4': ('id-cmc-dataReturn', ),
+ '1.3.6.1.5.5.7.7.5': ('id-cmc-transactionId', ),
+ '1.3.6.1.5.5.7.7.6': ('id-cmc-senderNonce', ),
+ '1.3.6.1.5.5.7.7.7': ('id-cmc-recipientNonce', ),
+ '1.3.6.1.5.5.7.7.8': ('id-cmc-addExtensions', ),
+ '1.3.6.1.5.5.7.7.9': ('id-cmc-encryptedPOP', ),
+ '1.3.6.1.5.5.7.7.10': ('id-cmc-decryptedPOP', ),
+ '1.3.6.1.5.5.7.7.11': ('id-cmc-lraPOPWitness', ),
+ '1.3.6.1.5.5.7.7.15': ('id-cmc-getCert', ),
+ '1.3.6.1.5.5.7.7.16': ('id-cmc-getCRL', ),
+ '1.3.6.1.5.5.7.7.17': ('id-cmc-revokeRequest', ),
+ '1.3.6.1.5.5.7.7.18': ('id-cmc-regInfo', ),
+ '1.3.6.1.5.5.7.7.19': ('id-cmc-responseInfo', ),
+ '1.3.6.1.5.5.7.7.21': ('id-cmc-queryPending', ),
+ '1.3.6.1.5.5.7.7.22': ('id-cmc-popLinkRandom', ),
+ '1.3.6.1.5.5.7.7.23': ('id-cmc-popLinkWitness', ),
+ '1.3.6.1.5.5.7.7.24': ('id-cmc-confirmCertAcceptance', ),
+ '1.3.6.1.5.5.7.8': ('id-on', ),
+ '1.3.6.1.5.5.7.8.1': ('id-on-personalData', ),
+ '1.3.6.1.5.5.7.8.3': ('Permanent Identifier', 'id-on-permanentIdentifier'),
+ '1.3.6.1.5.5.7.9': ('id-pda', ),
+ '1.3.6.1.5.5.7.9.1': ('id-pda-dateOfBirth', ),
+ '1.3.6.1.5.5.7.9.2': ('id-pda-placeOfBirth', ),
+ '1.3.6.1.5.5.7.9.3': ('id-pda-gender', ),
+ '1.3.6.1.5.5.7.9.4': ('id-pda-countryOfCitizenship', ),
+ '1.3.6.1.5.5.7.9.5': ('id-pda-countryOfResidence', ),
+ '1.3.6.1.5.5.7.10': ('id-aca', ),
+ '1.3.6.1.5.5.7.10.1': ('id-aca-authenticationInfo', ),
+ '1.3.6.1.5.5.7.10.2': ('id-aca-accessIdentity', ),
+ '1.3.6.1.5.5.7.10.3': ('id-aca-chargingIdentity', ),
+ '1.3.6.1.5.5.7.10.4': ('id-aca-group', ),
+ '1.3.6.1.5.5.7.10.5': ('id-aca-role', ),
+ '1.3.6.1.5.5.7.10.6': ('id-aca-encAttrs', ),
+ '1.3.6.1.5.5.7.11': ('id-qcs', ),
+ '1.3.6.1.5.5.7.11.1': ('id-qcs-pkixQCSyntax-v1', ),
+ '1.3.6.1.5.5.7.12': ('id-cct', ),
+ '1.3.6.1.5.5.7.12.1': ('id-cct-crs', ),
+ '1.3.6.1.5.5.7.12.2': ('id-cct-PKIData', ),
+ '1.3.6.1.5.5.7.12.3': ('id-cct-PKIResponse', ),
+ '1.3.6.1.5.5.7.21': ('id-ppl', ),
+ '1.3.6.1.5.5.7.21.0': ('Any language', 'id-ppl-anyLanguage'),
+ '1.3.6.1.5.5.7.21.1': ('Inherit all', 'id-ppl-inheritAll'),
+ '1.3.6.1.5.5.7.21.2': ('Independent', 'id-ppl-independent'),
+ '1.3.6.1.5.5.7.48': ('id-ad', ),
+ '1.3.6.1.5.5.7.48.1': ('OCSP', 'OCSP', 'id-pkix-OCSP'),
+ '1.3.6.1.5.5.7.48.1.1': ('Basic OCSP Response', 'basicOCSPResponse'),
+ '1.3.6.1.5.5.7.48.1.2': ('OCSP Nonce', 'Nonce'),
+ '1.3.6.1.5.5.7.48.1.3': ('OCSP CRL ID', 'CrlID'),
+ '1.3.6.1.5.5.7.48.1.4': ('Acceptable OCSP Responses', 'acceptableResponses'),
+ '1.3.6.1.5.5.7.48.1.5': ('OCSP No Check', 'noCheck'),
+ '1.3.6.1.5.5.7.48.1.6': ('OCSP Archive Cutoff', 'archiveCutoff'),
+ '1.3.6.1.5.5.7.48.1.7': ('OCSP Service Locator', 'serviceLocator'),
+ '1.3.6.1.5.5.7.48.1.8': ('Extended OCSP Status', 'extendedStatus'),
+ '1.3.6.1.5.5.7.48.1.9': ('valid', ),
+ '1.3.6.1.5.5.7.48.1.10': ('path', ),
+ '1.3.6.1.5.5.7.48.1.11': ('Trust Root', 'trustRoot'),
+ '1.3.6.1.5.5.7.48.2': ('CA Issuers', 'caIssuers'),
+ '1.3.6.1.5.5.7.48.3': ('AD Time Stamping', 'ad_timestamping'),
+ '1.3.6.1.5.5.7.48.4': ('ad dvcs', 'AD_DVCS'),
+ '1.3.6.1.5.5.7.48.5': ('CA Repository', 'caRepository'),
+ '1.3.6.1.5.5.8.1.1': ('hmac-md5', 'HMAC-MD5'),
+ '1.3.6.1.5.5.8.1.2': ('hmac-sha1', 'HMAC-SHA1'),
+ '1.3.6.1.6': ('SNMPv2', 'snmpv2'),
+ '1.3.6.1.7': ('Mail', ),
+ '1.3.6.1.7.1': ('MIME MHS', 'mime-mhs'),
+ '1.3.6.1.7.1.1': ('mime-mhs-headings', 'mime-mhs-headings'),
+ '1.3.6.1.7.1.1.1': ('id-hex-partial-message', 'id-hex-partial-message'),
+ '1.3.6.1.7.1.1.2': ('id-hex-multipart-message', 'id-hex-multipart-message'),
+ '1.3.6.1.7.1.2': ('mime-mhs-bodies', 'mime-mhs-bodies'),
+ '1.3.14.3.2': ('algorithm', 'algorithm'),
+ '1.3.14.3.2.3': ('md5WithRSA', 'RSA-NP-MD5'),
+ '1.3.14.3.2.6': ('des-ecb', 'DES-ECB'),
+ '1.3.14.3.2.7': ('des-cbc', 'DES-CBC'),
+ '1.3.14.3.2.8': ('des-ofb', 'DES-OFB'),
+ '1.3.14.3.2.9': ('des-cfb', 'DES-CFB'),
+ '1.3.14.3.2.11': ('rsaSignature', ),
+ '1.3.14.3.2.12': ('dsaEncryption-old', 'DSA-old'),
+ '1.3.14.3.2.13': ('dsaWithSHA', 'DSA-SHA'),
+ '1.3.14.3.2.15': ('shaWithRSAEncryption', 'RSA-SHA'),
+ '1.3.14.3.2.17': ('des-ede', 'DES-EDE'),
+ '1.3.14.3.2.18': ('sha', 'SHA'),
+ '1.3.14.3.2.26': ('sha1', 'SHA1'),
+ '1.3.14.3.2.27': ('dsaWithSHA1-old', 'DSA-SHA1-old'),
+ '1.3.14.3.2.29': ('sha1WithRSA', 'RSA-SHA1-2'),
+ '1.3.36.3.2.1': ('ripemd160', 'RIPEMD160'),
+ '1.3.36.3.3.1.2': ('ripemd160WithRSA', 'RSA-RIPEMD160'),
+ '1.3.36.3.3.2.8.1.1.1': ('brainpoolP160r1', ),
+ '1.3.36.3.3.2.8.1.1.2': ('brainpoolP160t1', ),
+ '1.3.36.3.3.2.8.1.1.3': ('brainpoolP192r1', ),
+ '1.3.36.3.3.2.8.1.1.4': ('brainpoolP192t1', ),
+ '1.3.36.3.3.2.8.1.1.5': ('brainpoolP224r1', ),
+ '1.3.36.3.3.2.8.1.1.6': ('brainpoolP224t1', ),
+ '1.3.36.3.3.2.8.1.1.7': ('brainpoolP256r1', ),
+ '1.3.36.3.3.2.8.1.1.8': ('brainpoolP256t1', ),
+ '1.3.36.3.3.2.8.1.1.9': ('brainpoolP320r1', ),
+ '1.3.36.3.3.2.8.1.1.10': ('brainpoolP320t1', ),
+ '1.3.36.3.3.2.8.1.1.11': ('brainpoolP384r1', ),
+ '1.3.36.3.3.2.8.1.1.12': ('brainpoolP384t1', ),
+ '1.3.36.3.3.2.8.1.1.13': ('brainpoolP512r1', ),
+ '1.3.36.3.3.2.8.1.1.14': ('brainpoolP512t1', ),
+ '1.3.36.8.3.3': ('Professional Information or basis for Admission', 'x509ExtAdmission'),
+ '1.3.101.1.4.1': ('Strong Extranet ID', 'SXNetID'),
+ '1.3.101.110': ('X25519', ),
+ '1.3.101.111': ('X448', ),
+ '1.3.101.112': ('ED25519', ),
+ '1.3.101.113': ('ED448', ),
+ '1.3.111': ('ieee', ),
+ '1.3.111.2.1619': ('IEEE Security in Storage Working Group', 'ieee-siswg'),
+ '1.3.111.2.1619.0.1.1': ('aes-128-xts', 'AES-128-XTS'),
+ '1.3.111.2.1619.0.1.2': ('aes-256-xts', 'AES-256-XTS'),
+ '1.3.132': ('certicom-arc', ),
+ '1.3.132.0': ('secg_ellipticCurve', ),
+ '1.3.132.0.1': ('sect163k1', ),
+ '1.3.132.0.2': ('sect163r1', ),
+ '1.3.132.0.3': ('sect239k1', ),
+ '1.3.132.0.4': ('sect113r1', ),
+ '1.3.132.0.5': ('sect113r2', ),
+ '1.3.132.0.6': ('secp112r1', ),
+ '1.3.132.0.7': ('secp112r2', ),
+ '1.3.132.0.8': ('secp160r1', ),
+ '1.3.132.0.9': ('secp160k1', ),
+ '1.3.132.0.10': ('secp256k1', ),
+ '1.3.132.0.15': ('sect163r2', ),
+ '1.3.132.0.16': ('sect283k1', ),
+ '1.3.132.0.17': ('sect283r1', ),
+ '1.3.132.0.22': ('sect131r1', ),
+ '1.3.132.0.23': ('sect131r2', ),
+ '1.3.132.0.24': ('sect193r1', ),
+ '1.3.132.0.25': ('sect193r2', ),
+ '1.3.132.0.26': ('sect233k1', ),
+ '1.3.132.0.27': ('sect233r1', ),
+ '1.3.132.0.28': ('secp128r1', ),
+ '1.3.132.0.29': ('secp128r2', ),
+ '1.3.132.0.30': ('secp160r2', ),
+ '1.3.132.0.31': ('secp192k1', ),
+ '1.3.132.0.32': ('secp224k1', ),
+ '1.3.132.0.33': ('secp224r1', ),
+ '1.3.132.0.34': ('secp384r1', ),
+ '1.3.132.0.35': ('secp521r1', ),
+ '1.3.132.0.36': ('sect409k1', ),
+ '1.3.132.0.37': ('sect409r1', ),
+ '1.3.132.0.38': ('sect571k1', ),
+ '1.3.132.0.39': ('sect571r1', ),
+ '1.3.132.1': ('secg-scheme', ),
+ '1.3.132.1.11.0': ('dhSinglePass-stdDH-sha224kdf-scheme', ),
+ '1.3.132.1.11.1': ('dhSinglePass-stdDH-sha256kdf-scheme', ),
+ '1.3.132.1.11.2': ('dhSinglePass-stdDH-sha384kdf-scheme', ),
+ '1.3.132.1.11.3': ('dhSinglePass-stdDH-sha512kdf-scheme', ),
+ '1.3.132.1.14.0': ('dhSinglePass-cofactorDH-sha224kdf-scheme', ),
+ '1.3.132.1.14.1': ('dhSinglePass-cofactorDH-sha256kdf-scheme', ),
+ '1.3.132.1.14.2': ('dhSinglePass-cofactorDH-sha384kdf-scheme', ),
+ '1.3.132.1.14.3': ('dhSinglePass-cofactorDH-sha512kdf-scheme', ),
+ '1.3.133.16.840.63.0': ('x9-63-scheme', ),
+ '1.3.133.16.840.63.0.2': ('dhSinglePass-stdDH-sha1kdf-scheme', ),
+ '1.3.133.16.840.63.0.3': ('dhSinglePass-cofactorDH-sha1kdf-scheme', ),
+ '2': ('joint-iso-itu-t', 'JOINT-ISO-ITU-T', 'joint-iso-ccitt'),
+ '2.5': ('directory services (X.500)', 'X500'),
+ '2.5.1.5': ('Selected Attribute Types', 'selected-attribute-types'),
+ '2.5.1.5.55': ('clearance', ),
+ '2.5.4': ('X509', ),
+ '2.5.4.3': ('commonName', 'CN'),
+ '2.5.4.4': ('surname', 'SN'),
+ '2.5.4.5': ('serialNumber', ),
+ '2.5.4.6': ('countryName', 'C'),
+ '2.5.4.7': ('localityName', 'L'),
+ '2.5.4.8': ('stateOrProvinceName', 'ST'),
+ '2.5.4.9': ('streetAddress', 'street'),
+ '2.5.4.10': ('organizationName', 'O'),
+ '2.5.4.11': ('organizationalUnitName', 'OU'),
+ '2.5.4.12': ('title', 'title'),
+ '2.5.4.13': ('description', ),
+ '2.5.4.14': ('searchGuide', ),
+ '2.5.4.15': ('businessCategory', ),
+ '2.5.4.16': ('postalAddress', ),
+ '2.5.4.17': ('postalCode', ),
+ '2.5.4.18': ('postOfficeBox', ),
+ '2.5.4.19': ('physicalDeliveryOfficeName', ),
+ '2.5.4.20': ('telephoneNumber', ),
+ '2.5.4.21': ('telexNumber', ),
+ '2.5.4.22': ('teletexTerminalIdentifier', ),
+ '2.5.4.23': ('facsimileTelephoneNumber', ),
+ '2.5.4.24': ('x121Address', ),
+ '2.5.4.25': ('internationaliSDNNumber', ),
+ '2.5.4.26': ('registeredAddress', ),
+ '2.5.4.27': ('destinationIndicator', ),
+ '2.5.4.28': ('preferredDeliveryMethod', ),
+ '2.5.4.29': ('presentationAddress', ),
+ '2.5.4.30': ('supportedApplicationContext', ),
+ '2.5.4.31': ('member', ),
+ '2.5.4.32': ('owner', ),
+ '2.5.4.33': ('roleOccupant', ),
+ '2.5.4.34': ('seeAlso', ),
+ '2.5.4.35': ('userPassword', ),
+ '2.5.4.36': ('userCertificate', ),
+ '2.5.4.37': ('cACertificate', ),
+ '2.5.4.38': ('authorityRevocationList', ),
+ '2.5.4.39': ('certificateRevocationList', ),
+ '2.5.4.40': ('crossCertificatePair', ),
+ '2.5.4.41': ('name', 'name'),
+ '2.5.4.42': ('givenName', 'GN'),
+ '2.5.4.43': ('initials', 'initials'),
+ '2.5.4.44': ('generationQualifier', ),
+ '2.5.4.45': ('x500UniqueIdentifier', ),
+ '2.5.4.46': ('dnQualifier', 'dnQualifier'),
+ '2.5.4.47': ('enhancedSearchGuide', ),
+ '2.5.4.48': ('protocolInformation', ),
+ '2.5.4.49': ('distinguishedName', ),
+ '2.5.4.50': ('uniqueMember', ),
+ '2.5.4.51': ('houseIdentifier', ),
+ '2.5.4.52': ('supportedAlgorithms', ),
+ '2.5.4.53': ('deltaRevocationList', ),
+ '2.5.4.54': ('dmdName', ),
+ '2.5.4.65': ('pseudonym', ),
+ '2.5.4.72': ('role', 'role'),
+ '2.5.4.97': ('organizationIdentifier', ),
+ '2.5.4.98': ('countryCode3c', 'c3'),
+ '2.5.4.99': ('countryCode3n', 'n3'),
+ '2.5.4.100': ('dnsName', ),
+ '2.5.8': ('directory services - algorithms', 'X500algorithms'),
+ '2.5.8.1.1': ('rsa', 'RSA'),
+ '2.5.8.3.100': ('mdc2WithRSA', 'RSA-MDC2'),
+ '2.5.8.3.101': ('mdc2', 'MDC2'),
+ '2.5.29': ('id-ce', ),
+ '2.5.29.9': ('X509v3 Subject Directory Attributes', 'subjectDirectoryAttributes'),
+ '2.5.29.14': ('X509v3 Subject Key Identifier', 'subjectKeyIdentifier'),
+ '2.5.29.15': ('X509v3 Key Usage', 'keyUsage'),
+ '2.5.29.16': ('X509v3 Private Key Usage Period', 'privateKeyUsagePeriod'),
+ '2.5.29.17': ('X509v3 Subject Alternative Name', 'subjectAltName'),
+ '2.5.29.18': ('X509v3 Issuer Alternative Name', 'issuerAltName'),
+ '2.5.29.19': ('X509v3 Basic Constraints', 'basicConstraints'),
+ '2.5.29.20': ('X509v3 CRL Number', 'crlNumber'),
+ '2.5.29.21': ('X509v3 CRL Reason Code', 'CRLReason'),
+ '2.5.29.23': ('Hold Instruction Code', 'holdInstructionCode'),
+ '2.5.29.24': ('Invalidity Date', 'invalidityDate'),
+ '2.5.29.27': ('X509v3 Delta CRL Indicator', 'deltaCRL'),
+ '2.5.29.28': ('X509v3 Issuing Distribution Point', 'issuingDistributionPoint'),
+ '2.5.29.29': ('X509v3 Certificate Issuer', 'certificateIssuer'),
+ '2.5.29.30': ('X509v3 Name Constraints', 'nameConstraints'),
+ '2.5.29.31': ('X509v3 CRL Distribution Points', 'crlDistributionPoints'),
+ '2.5.29.32': ('X509v3 Certificate Policies', 'certificatePolicies'),
+ '2.5.29.32.0': ('X509v3 Any Policy', 'anyPolicy'),
+ '2.5.29.33': ('X509v3 Policy Mappings', 'policyMappings'),
+ '2.5.29.35': ('X509v3 Authority Key Identifier', 'authorityKeyIdentifier'),
+ '2.5.29.36': ('X509v3 Policy Constraints', 'policyConstraints'),
+ '2.5.29.37': ('X509v3 Extended Key Usage', 'extendedKeyUsage'),
+ '2.5.29.37.0': ('Any Extended Key Usage', 'anyExtendedKeyUsage'),
+ '2.5.29.46': ('X509v3 Freshest CRL', 'freshestCRL'),
+ '2.5.29.54': ('X509v3 Inhibit Any Policy', 'inhibitAnyPolicy'),
+ '2.5.29.55': ('X509v3 AC Targeting', 'targetInformation'),
+ '2.5.29.56': ('X509v3 No Revocation Available', 'noRevAvail'),
+ '2.16.840.1.101.3': ('csor', ),
+ '2.16.840.1.101.3.4': ('nistAlgorithms', ),
+ '2.16.840.1.101.3.4.1': ('aes', ),
+ '2.16.840.1.101.3.4.1.1': ('aes-128-ecb', 'AES-128-ECB'),
+ '2.16.840.1.101.3.4.1.2': ('aes-128-cbc', 'AES-128-CBC'),
+ '2.16.840.1.101.3.4.1.3': ('aes-128-ofb', 'AES-128-OFB'),
+ '2.16.840.1.101.3.4.1.4': ('aes-128-cfb', 'AES-128-CFB'),
+ '2.16.840.1.101.3.4.1.5': ('id-aes128-wrap', ),
+ '2.16.840.1.101.3.4.1.6': ('aes-128-gcm', 'id-aes128-GCM'),
+ '2.16.840.1.101.3.4.1.7': ('aes-128-ccm', 'id-aes128-CCM'),
+ '2.16.840.1.101.3.4.1.8': ('id-aes128-wrap-pad', ),
+ '2.16.840.1.101.3.4.1.21': ('aes-192-ecb', 'AES-192-ECB'),
+ '2.16.840.1.101.3.4.1.22': ('aes-192-cbc', 'AES-192-CBC'),
+ '2.16.840.1.101.3.4.1.23': ('aes-192-ofb', 'AES-192-OFB'),
+ '2.16.840.1.101.3.4.1.24': ('aes-192-cfb', 'AES-192-CFB'),
+ '2.16.840.1.101.3.4.1.25': ('id-aes192-wrap', ),
+ '2.16.840.1.101.3.4.1.26': ('aes-192-gcm', 'id-aes192-GCM'),
+ '2.16.840.1.101.3.4.1.27': ('aes-192-ccm', 'id-aes192-CCM'),
+ '2.16.840.1.101.3.4.1.28': ('id-aes192-wrap-pad', ),
+ '2.16.840.1.101.3.4.1.41': ('aes-256-ecb', 'AES-256-ECB'),
+ '2.16.840.1.101.3.4.1.42': ('aes-256-cbc', 'AES-256-CBC'),
+ '2.16.840.1.101.3.4.1.43': ('aes-256-ofb', 'AES-256-OFB'),
+ '2.16.840.1.101.3.4.1.44': ('aes-256-cfb', 'AES-256-CFB'),
+ '2.16.840.1.101.3.4.1.45': ('id-aes256-wrap', ),
+ '2.16.840.1.101.3.4.1.46': ('aes-256-gcm', 'id-aes256-GCM'),
+ '2.16.840.1.101.3.4.1.47': ('aes-256-ccm', 'id-aes256-CCM'),
+ '2.16.840.1.101.3.4.1.48': ('id-aes256-wrap-pad', ),
+ '2.16.840.1.101.3.4.2': ('nist_hashalgs', ),
+ '2.16.840.1.101.3.4.2.1': ('sha256', 'SHA256'),
+ '2.16.840.1.101.3.4.2.2': ('sha384', 'SHA384'),
+ '2.16.840.1.101.3.4.2.3': ('sha512', 'SHA512'),
+ '2.16.840.1.101.3.4.2.4': ('sha224', 'SHA224'),
+ '2.16.840.1.101.3.4.2.5': ('sha512-224', 'SHA512-224'),
+ '2.16.840.1.101.3.4.2.6': ('sha512-256', 'SHA512-256'),
+ '2.16.840.1.101.3.4.2.7': ('sha3-224', 'SHA3-224'),
+ '2.16.840.1.101.3.4.2.8': ('sha3-256', 'SHA3-256'),
+ '2.16.840.1.101.3.4.2.9': ('sha3-384', 'SHA3-384'),
+ '2.16.840.1.101.3.4.2.10': ('sha3-512', 'SHA3-512'),
+ '2.16.840.1.101.3.4.2.11': ('shake128', 'SHAKE128'),
+ '2.16.840.1.101.3.4.2.12': ('shake256', 'SHAKE256'),
+ '2.16.840.1.101.3.4.2.13': ('hmac-sha3-224', 'id-hmacWithSHA3-224'),
+ '2.16.840.1.101.3.4.2.14': ('hmac-sha3-256', 'id-hmacWithSHA3-256'),
+ '2.16.840.1.101.3.4.2.15': ('hmac-sha3-384', 'id-hmacWithSHA3-384'),
+ '2.16.840.1.101.3.4.2.16': ('hmac-sha3-512', 'id-hmacWithSHA3-512'),
+ '2.16.840.1.101.3.4.3': ('dsa_with_sha2', 'sigAlgs'),
+ '2.16.840.1.101.3.4.3.1': ('dsa_with_SHA224', ),
+ '2.16.840.1.101.3.4.3.2': ('dsa_with_SHA256', ),
+ '2.16.840.1.101.3.4.3.3': ('dsa_with_SHA384', 'id-dsa-with-sha384'),
+ '2.16.840.1.101.3.4.3.4': ('dsa_with_SHA512', 'id-dsa-with-sha512'),
+ '2.16.840.1.101.3.4.3.5': ('dsa_with_SHA3-224', 'id-dsa-with-sha3-224'),
+ '2.16.840.1.101.3.4.3.6': ('dsa_with_SHA3-256', 'id-dsa-with-sha3-256'),
+ '2.16.840.1.101.3.4.3.7': ('dsa_with_SHA3-384', 'id-dsa-with-sha3-384'),
+ '2.16.840.1.101.3.4.3.8': ('dsa_with_SHA3-512', 'id-dsa-with-sha3-512'),
+ '2.16.840.1.101.3.4.3.9': ('ecdsa_with_SHA3-224', 'id-ecdsa-with-sha3-224'),
+ '2.16.840.1.101.3.4.3.10': ('ecdsa_with_SHA3-256', 'id-ecdsa-with-sha3-256'),
+ '2.16.840.1.101.3.4.3.11': ('ecdsa_with_SHA3-384', 'id-ecdsa-with-sha3-384'),
+ '2.16.840.1.101.3.4.3.12': ('ecdsa_with_SHA3-512', 'id-ecdsa-with-sha3-512'),
+ '2.16.840.1.101.3.4.3.13': ('RSA-SHA3-224', 'id-rsassa-pkcs1-v1_5-with-sha3-224'),
+ '2.16.840.1.101.3.4.3.14': ('RSA-SHA3-256', 'id-rsassa-pkcs1-v1_5-with-sha3-256'),
+ '2.16.840.1.101.3.4.3.15': ('RSA-SHA3-384', 'id-rsassa-pkcs1-v1_5-with-sha3-384'),
+ '2.16.840.1.101.3.4.3.16': ('RSA-SHA3-512', 'id-rsassa-pkcs1-v1_5-with-sha3-512'),
+ '2.16.840.1.113730': ('Netscape Communications Corp.', 'Netscape'),
+ '2.16.840.1.113730.1': ('Netscape Certificate Extension', 'nsCertExt'),
+ '2.16.840.1.113730.1.1': ('Netscape Cert Type', 'nsCertType'),
+ '2.16.840.1.113730.1.2': ('Netscape Base Url', 'nsBaseUrl'),
+ '2.16.840.1.113730.1.3': ('Netscape Revocation Url', 'nsRevocationUrl'),
+ '2.16.840.1.113730.1.4': ('Netscape CA Revocation Url', 'nsCaRevocationUrl'),
+ '2.16.840.1.113730.1.7': ('Netscape Renewal Url', 'nsRenewalUrl'),
+ '2.16.840.1.113730.1.8': ('Netscape CA Policy Url', 'nsCaPolicyUrl'),
+ '2.16.840.1.113730.1.12': ('Netscape SSL Server Name', 'nsSslServerName'),
+ '2.16.840.1.113730.1.13': ('Netscape Comment', 'nsComment'),
+ '2.16.840.1.113730.2': ('Netscape Data Type', 'nsDataType'),
+ '2.16.840.1.113730.2.5': ('Netscape Certificate Sequence', 'nsCertSequence'),
+ '2.16.840.1.113730.4.1': ('Netscape Server Gated Crypto', 'nsSGC'),
+ '2.23': ('International Organizations', 'international-organizations'),
+ '2.23.42': ('Secure Electronic Transactions', 'id-set'),
+ '2.23.42.0': ('content types', 'set-ctype'),
+ '2.23.42.0.0': ('setct-PANData', ),
+ '2.23.42.0.1': ('setct-PANToken', ),
+ '2.23.42.0.2': ('setct-PANOnly', ),
+ '2.23.42.0.3': ('setct-OIData', ),
+ '2.23.42.0.4': ('setct-PI', ),
+ '2.23.42.0.5': ('setct-PIData', ),
+ '2.23.42.0.6': ('setct-PIDataUnsigned', ),
+ '2.23.42.0.7': ('setct-HODInput', ),
+ '2.23.42.0.8': ('setct-AuthResBaggage', ),
+ '2.23.42.0.9': ('setct-AuthRevReqBaggage', ),
+ '2.23.42.0.10': ('setct-AuthRevResBaggage', ),
+ '2.23.42.0.11': ('setct-CapTokenSeq', ),
+ '2.23.42.0.12': ('setct-PInitResData', ),
+ '2.23.42.0.13': ('setct-PI-TBS', ),
+ '2.23.42.0.14': ('setct-PResData', ),
+ '2.23.42.0.16': ('setct-AuthReqTBS', ),
+ '2.23.42.0.17': ('setct-AuthResTBS', ),
+ '2.23.42.0.18': ('setct-AuthResTBSX', ),
+ '2.23.42.0.19': ('setct-AuthTokenTBS', ),
+ '2.23.42.0.20': ('setct-CapTokenData', ),
+ '2.23.42.0.21': ('setct-CapTokenTBS', ),
+ '2.23.42.0.22': ('setct-AcqCardCodeMsg', ),
+ '2.23.42.0.23': ('setct-AuthRevReqTBS', ),
+ '2.23.42.0.24': ('setct-AuthRevResData', ),
+ '2.23.42.0.25': ('setct-AuthRevResTBS', ),
+ '2.23.42.0.26': ('setct-CapReqTBS', ),
+ '2.23.42.0.27': ('setct-CapReqTBSX', ),
+ '2.23.42.0.28': ('setct-CapResData', ),
+ '2.23.42.0.29': ('setct-CapRevReqTBS', ),
+ '2.23.42.0.30': ('setct-CapRevReqTBSX', ),
+ '2.23.42.0.31': ('setct-CapRevResData', ),
+ '2.23.42.0.32': ('setct-CredReqTBS', ),
+ '2.23.42.0.33': ('setct-CredReqTBSX', ),
+ '2.23.42.0.34': ('setct-CredResData', ),
+ '2.23.42.0.35': ('setct-CredRevReqTBS', ),
+ '2.23.42.0.36': ('setct-CredRevReqTBSX', ),
+ '2.23.42.0.37': ('setct-CredRevResData', ),
+ '2.23.42.0.38': ('setct-PCertReqData', ),
+ '2.23.42.0.39': ('setct-PCertResTBS', ),
+ '2.23.42.0.40': ('setct-BatchAdminReqData', ),
+ '2.23.42.0.41': ('setct-BatchAdminResData', ),
+ '2.23.42.0.42': ('setct-CardCInitResTBS', ),
+ '2.23.42.0.43': ('setct-MeAqCInitResTBS', ),
+ '2.23.42.0.44': ('setct-RegFormResTBS', ),
+ '2.23.42.0.45': ('setct-CertReqData', ),
+ '2.23.42.0.46': ('setct-CertReqTBS', ),
+ '2.23.42.0.47': ('setct-CertResData', ),
+ '2.23.42.0.48': ('setct-CertInqReqTBS', ),
+ '2.23.42.0.49': ('setct-ErrorTBS', ),
+ '2.23.42.0.50': ('setct-PIDualSignedTBE', ),
+ '2.23.42.0.51': ('setct-PIUnsignedTBE', ),
+ '2.23.42.0.52': ('setct-AuthReqTBE', ),
+ '2.23.42.0.53': ('setct-AuthResTBE', ),
+ '2.23.42.0.54': ('setct-AuthResTBEX', ),
+ '2.23.42.0.55': ('setct-AuthTokenTBE', ),
+ '2.23.42.0.56': ('setct-CapTokenTBE', ),
+ '2.23.42.0.57': ('setct-CapTokenTBEX', ),
+ '2.23.42.0.58': ('setct-AcqCardCodeMsgTBE', ),
+ '2.23.42.0.59': ('setct-AuthRevReqTBE', ),
+ '2.23.42.0.60': ('setct-AuthRevResTBE', ),
+ '2.23.42.0.61': ('setct-AuthRevResTBEB', ),
+ '2.23.42.0.62': ('setct-CapReqTBE', ),
+ '2.23.42.0.63': ('setct-CapReqTBEX', ),
+ '2.23.42.0.64': ('setct-CapResTBE', ),
+ '2.23.42.0.65': ('setct-CapRevReqTBE', ),
+ '2.23.42.0.66': ('setct-CapRevReqTBEX', ),
+ '2.23.42.0.67': ('setct-CapRevResTBE', ),
+ '2.23.42.0.68': ('setct-CredReqTBE', ),
+ '2.23.42.0.69': ('setct-CredReqTBEX', ),
+ '2.23.42.0.70': ('setct-CredResTBE', ),
+ '2.23.42.0.71': ('setct-CredRevReqTBE', ),
+ '2.23.42.0.72': ('setct-CredRevReqTBEX', ),
+ '2.23.42.0.73': ('setct-CredRevResTBE', ),
+ '2.23.42.0.74': ('setct-BatchAdminReqTBE', ),
+ '2.23.42.0.75': ('setct-BatchAdminResTBE', ),
+ '2.23.42.0.76': ('setct-RegFormReqTBE', ),
+ '2.23.42.0.77': ('setct-CertReqTBE', ),
+ '2.23.42.0.78': ('setct-CertReqTBEX', ),
+ '2.23.42.0.79': ('setct-CertResTBE', ),
+ '2.23.42.0.80': ('setct-CRLNotificationTBS', ),
+ '2.23.42.0.81': ('setct-CRLNotificationResTBS', ),
+ '2.23.42.0.82': ('setct-BCIDistributionTBS', ),
+ '2.23.42.1': ('message extensions', 'set-msgExt'),
+ '2.23.42.1.1': ('generic cryptogram', 'setext-genCrypt'),
+ '2.23.42.1.3': ('merchant initiated auth', 'setext-miAuth'),
+ '2.23.42.1.4': ('setext-pinSecure', ),
+ '2.23.42.1.5': ('setext-pinAny', ),
+ '2.23.42.1.7': ('setext-track2', ),
+ '2.23.42.1.8': ('additional verification', 'setext-cv'),
+ '2.23.42.3': ('set-attr', ),
+ '2.23.42.3.0': ('setAttr-Cert', ),
+ '2.23.42.3.0.0': ('set-rootKeyThumb', ),
+ '2.23.42.3.0.1': ('set-addPolicy', ),
+ '2.23.42.3.1': ('payment gateway capabilities', 'setAttr-PGWYcap'),
+ '2.23.42.3.2': ('setAttr-TokenType', ),
+ '2.23.42.3.2.1': ('setAttr-Token-EMV', ),
+ '2.23.42.3.2.2': ('setAttr-Token-B0Prime', ),
+ '2.23.42.3.3': ('issuer capabilities', 'setAttr-IssCap'),
+ '2.23.42.3.3.3': ('setAttr-IssCap-CVM', ),
+ '2.23.42.3.3.3.1': ('generate cryptogram', 'setAttr-GenCryptgrm'),
+ '2.23.42.3.3.4': ('setAttr-IssCap-T2', ),
+ '2.23.42.3.3.4.1': ('encrypted track 2', 'setAttr-T2Enc'),
+ '2.23.42.3.3.4.2': ('cleartext track 2', 'setAttr-T2cleartxt'),
+ '2.23.42.3.3.5': ('setAttr-IssCap-Sig', ),
+ '2.23.42.3.3.5.1': ('ICC or token signature', 'setAttr-TokICCsig'),
+ '2.23.42.3.3.5.2': ('secure device signature', 'setAttr-SecDevSig'),
+ '2.23.42.5': ('set-policy', ),
+ '2.23.42.5.0': ('set-policy-root', ),
+ '2.23.42.7': ('certificate extensions', 'set-certExt'),
+ '2.23.42.7.0': ('setCext-hashedRoot', ),
+ '2.23.42.7.1': ('setCext-certType', ),
+ '2.23.42.7.2': ('setCext-merchData', ),
+ '2.23.42.7.3': ('setCext-cCertRequired', ),
+ '2.23.42.7.4': ('setCext-tunneling', ),
+ '2.23.42.7.5': ('setCext-setExt', ),
+ '2.23.42.7.6': ('setCext-setQualf', ),
+ '2.23.42.7.7': ('setCext-PGWYcapabilities', ),
+ '2.23.42.7.8': ('setCext-TokenIdentifier', ),
+ '2.23.42.7.9': ('setCext-Track2Data', ),
+ '2.23.42.7.10': ('setCext-TokenType', ),
+ '2.23.42.7.11': ('setCext-IssuerCapabilities', ),
+ '2.23.42.8': ('set-brand', ),
+ '2.23.42.8.1': ('set-brand-IATA-ATA', ),
+ '2.23.42.8.4': ('set-brand-Visa', ),
+ '2.23.42.8.5': ('set-brand-MasterCard', ),
+ '2.23.42.8.30': ('set-brand-Diners', ),
+ '2.23.42.8.34': ('set-brand-AmericanExpress', ),
+ '2.23.42.8.35': ('set-brand-JCB', ),
+ '2.23.42.8.6011': ('set-brand-Novus', ),
+ '2.23.43': ('wap', ),
+ '2.23.43.1': ('wap-wsg', ),
+ '2.23.43.1.4': ('wap-wsg-idm-ecid', ),
+ '2.23.43.1.4.1': ('wap-wsg-idm-ecid-wtls1', ),
+ '2.23.43.1.4.3': ('wap-wsg-idm-ecid-wtls3', ),
+ '2.23.43.1.4.4': ('wap-wsg-idm-ecid-wtls4', ),
+ '2.23.43.1.4.5': ('wap-wsg-idm-ecid-wtls5', ),
+ '2.23.43.1.4.6': ('wap-wsg-idm-ecid-wtls6', ),
+ '2.23.43.1.4.7': ('wap-wsg-idm-ecid-wtls7', ),
+ '2.23.43.1.4.8': ('wap-wsg-idm-ecid-wtls8', ),
+ '2.23.43.1.4.9': ('wap-wsg-idm-ecid-wtls9', ),
+ '2.23.43.1.4.10': ('wap-wsg-idm-ecid-wtls10', ),
+ '2.23.43.1.4.11': ('wap-wsg-idm-ecid-wtls11', ),
+ '2.23.43.1.4.12': ('wap-wsg-idm-ecid-wtls12', ),
+}
+# #####################################################################################
+# #####################################################################################
+
+_OID_LOOKUP = dict()
+_NORMALIZE_NAMES = dict()
+_NORMALIZE_NAMES_SHORT = dict()
+
+for dotted, names in _OID_MAP.items():
+ for name in names:
+ if name in _NORMALIZE_NAMES and _OID_LOOKUP[name] != dotted:
+ raise AssertionError(
+ 'Name collision during setup: "{0}" for OIDs {1} and {2}'
+ .format(name, dotted, _OID_LOOKUP[name])
+ )
+ _NORMALIZE_NAMES[name] = names[0]
+ _NORMALIZE_NAMES_SHORT[name] = names[-1]
+ _OID_LOOKUP[name] = dotted
+for alias, original in [('userID', 'userId')]:
+ if alias in _NORMALIZE_NAMES:
+ raise AssertionError(
+ 'Name collision during adding aliases: "{0}" (alias for "{1}") is already mapped to OID {2}'
+ .format(alias, original, _OID_LOOKUP[alias])
+ )
+ _NORMALIZE_NAMES[alias] = original
+ _NORMALIZE_NAMES_SHORT[alias] = _NORMALIZE_NAMES_SHORT[original]
+ _OID_LOOKUP[alias] = _OID_LOOKUP[original]
+
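+# After this setup every known name or alias resolves consistently; for the
+# commonName entry above, for example (editor's illustration):
+#
+#     _OID_LOOKUP['CN'] == '2.5.4.3'
+#     _NORMALIZE_NAMES['CN'] == 'commonName'
+#     _NORMALIZE_NAMES_SHORT['commonName'] == 'CN'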
+
+def pyopenssl_normalize_name(name, short=False):
+ nid = OpenSSL._util.lib.OBJ_txt2nid(to_bytes(name))
+ if nid != 0:
+ b_name = OpenSSL._util.lib.OBJ_nid2ln(nid)
+ name = to_text(OpenSSL._util.ffi.string(b_name))
+ if short:
+ return _NORMALIZE_NAMES_SHORT.get(name, name)
+ else:
+ return _NORMALIZE_NAMES.get(name, name)
+
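+# With the linked OpenSSL resolving the NID, this yields e.g. (editor's
+# illustration): pyopenssl_normalize_name('commonName', short=True) == 'CN'
+# and pyopenssl_normalize_name('CN') == 'commonName'.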
+
+# #####################################################################################
+# #####################################################################################
+# # This excerpt is dual licensed under the terms of the Apache License, Version
+# # 2.0, and the BSD License. See the LICENSE file at
+# # https://github.com/pyca/cryptography/blob/master/LICENSE for complete details.
+# #
+# # Adapted from cryptography's hazmat/backends/openssl/decode_asn1.py
+# #
+# # Copyright (c) 2015, 2016 Paul Kehrer (@reaperhulk)
+# # Copyright (c) 2017 Fraser Tweedale (@frasertweedale)
+# #
+# # Relevant commits from cryptography project (https://github.com/pyca/cryptography):
+# # pyca/cryptography@719d536dd691e84e208534798f2eb4f82aaa2e07
+# # pyca/cryptography@5ab6d6a5c05572bd1c75f05baf264a2d0001894a
+# # pyca/cryptography@2e776e20eb60378e0af9b7439000d0e80da7c7e3
+# # pyca/cryptography@fb309ed24647d1be9e319b61b1f2aa8ebb87b90b
+# # pyca/cryptography@2917e460993c475c72d7146c50dc3bbc2414280d
+# # pyca/cryptography@3057f91ea9a05fb593825006d87a391286a4d828
+# # pyca/cryptography@d607dd7e5bc5c08854ec0c9baff70ba4a35be36f
+def _obj2txt(openssl_lib, openssl_ffi, obj):
+    # The buffer length is set to 80 on the recommendation of
+ # https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values
+ #
+ # But OIDs longer than this occur in real life (e.g. Active
+ # Directory makes some very long OIDs). So we need to detect
+ # and properly handle the case where the default buffer is not
+ # big enough.
+ #
+ buf_len = 80
+ buf = openssl_ffi.new("char[]", buf_len)
+
+ # 'res' is the number of bytes that *would* be written if the
+ # buffer is large enough. If 'res' > buf_len - 1, we need to
+ # alloc a big-enough buffer and go again.
+ res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
+ if res > buf_len - 1: # account for terminating null byte
+ buf_len = res + 1
+ buf = openssl_ffi.new("char[]", buf_len)
+ res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
+ return openssl_ffi.buffer(buf, res)[:].decode()
+# #####################################################################################
+# #####################################################################################
+
+
+def cryptography_get_extensions_from_cert(cert):
+ # Since cryptography won't give us the DER value for an extension
+ # (that is only stored for unrecognized extensions), we have to re-do
+    # the extension parsing ourselves.
+ result = dict()
+ backend = cert._backend
+ x509_obj = cert._x509
+
+ for i in range(backend._lib.X509_get_ext_count(x509_obj)):
+ ext = backend._lib.X509_get_ext(x509_obj, i)
+ if ext == backend._ffi.NULL:
+ continue
+ crit = backend._lib.X509_EXTENSION_get_critical(ext)
+ data = backend._lib.X509_EXTENSION_get_data(ext)
+ backend.openssl_assert(data != backend._ffi.NULL)
+ der = backend._ffi.buffer(data.data, data.length)[:]
+ entry = dict(
+ critical=(crit == 1),
+ value=base64.b64encode(der),
+ )
+ oid = _obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext))
+ result[oid] = entry
+ return result
+
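+# Example result shape (editor's illustration, for a CA certificate with a
+# critical basicConstraints extension):
+#
+#     {'2.5.29.19': {'critical': True, 'value': b'MAMBAf8='}, ...}
+#
+# where 'value' is the base64-encoded DER of the extension contents.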
+
+def cryptography_get_extensions_from_csr(csr):
+ # Since cryptography won't give us the DER value for an extension
+ # (that is only stored for unrecognized extensions), we have to re-do
+    # the extension parsing ourselves.
+ result = dict()
+ backend = csr._backend
+
+ extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req)
+ extensions = backend._ffi.gc(
+ extensions,
+ lambda ext: backend._lib.sk_X509_EXTENSION_pop_free(
+ ext,
+ backend._ffi.addressof(backend._lib._original_lib, "X509_EXTENSION_free")
+ )
+ )
+
+ for i in range(backend._lib.sk_X509_EXTENSION_num(extensions)):
+ ext = backend._lib.sk_X509_EXTENSION_value(extensions, i)
+ if ext == backend._ffi.NULL:
+ continue
+ crit = backend._lib.X509_EXTENSION_get_critical(ext)
+ data = backend._lib.X509_EXTENSION_get_data(ext)
+ backend.openssl_assert(data != backend._ffi.NULL)
+ der = backend._ffi.buffer(data.data, data.length)[:]
+ entry = dict(
+ critical=(crit == 1),
+ value=base64.b64encode(der),
+ )
+ oid = _obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext))
+ result[oid] = entry
+ return result
+
+
+def pyopenssl_get_extensions_from_cert(cert):
+ # While pyOpenSSL allows us to get an extension's DER value, it won't
+ # give us the dotted string for an OID. So we have to do some magic to
+ # get hold of it.
+ result = dict()
+ ext_count = cert.get_extension_count()
+ for i in range(0, ext_count):
+ ext = cert.get_extension(i)
+ entry = dict(
+ critical=bool(ext.get_critical()),
+ value=base64.b64encode(ext.get_data()),
+ )
+ oid = _obj2txt(
+ OpenSSL._util.lib,
+ OpenSSL._util.ffi,
+ OpenSSL._util.lib.X509_EXTENSION_get_object(ext._extension)
+ )
+ # This could also be done a bit simpler:
+ #
+ # oid = _obj2txt(OpenSSL._util.lib, OpenSSL._util.ffi, OpenSSL._util.lib.OBJ_nid2obj(ext._nid))
+ #
+ # Unfortunately this gives the wrong result in case the linked OpenSSL
+ # doesn't know the OID. That's why we have to get the OID dotted string
+ # similarly to how cryptography does it.
+ result[oid] = entry
+ return result
+
+
+def pyopenssl_get_extensions_from_csr(csr):
+ # While pyOpenSSL allows us to get an extension's DER value, it won't
+ # give us the dotted string for an OID. So we have to do some magic to
+ # get hold of it.
+ result = dict()
+ for ext in csr.get_extensions():
+ entry = dict(
+ critical=bool(ext.get_critical()),
+ value=base64.b64encode(ext.get_data()),
+ )
+ oid = _obj2txt(
+ OpenSSL._util.lib,
+ OpenSSL._util.ffi,
+ OpenSSL._util.lib.X509_EXTENSION_get_object(ext._extension)
+ )
+ # This could also be done a bit simpler:
+ #
+ # oid = _obj2txt(OpenSSL._util.lib, OpenSSL._util.ffi, OpenSSL._util.lib.OBJ_nid2obj(ext._nid))
+ #
+ # Unfortunately this gives the wrong result in case the linked OpenSSL
+ # doesn't know the OID. That's why we have to get the OID dotted string
+ # similarly to how cryptography does it.
+ result[oid] = entry
+ return result
+
+
+def cryptography_name_to_oid(name):
+ dotted = _OID_LOOKUP.get(name)
+ if dotted is None:
+ raise OpenSSLObjectError('Cannot find OID for "{0}"'.format(name))
+ return x509.oid.ObjectIdentifier(dotted)
+
+
+def cryptography_oid_to_name(oid, short=False):
+ dotted_string = oid.dotted_string
+ names = _OID_MAP.get(dotted_string)
+ name = names[0] if names else oid._name
+ if short:
+ return _NORMALIZE_NAMES_SHORT.get(name, name)
+ else:
+ return _NORMALIZE_NAMES.get(name, name)
+
+
+def cryptography_get_name(name):
+ '''
+    Given a Subject Alternative Name string, returns the corresponding
+    cryptography GeneralName object (x509.DNSName, x509.IPAddress, ...).
+ Raises an OpenSSLObjectError if the name is unknown or cannot be parsed.
+ '''
+ try:
+ if name.startswith('DNS:'):
+ return x509.DNSName(to_text(name[4:]))
+ if name.startswith('IP:'):
+ return x509.IPAddress(ipaddress.ip_address(to_text(name[3:])))
+ if name.startswith('email:'):
+ return x509.RFC822Name(to_text(name[6:]))
+ if name.startswith('URI:'):
+ return x509.UniformResourceIdentifier(to_text(name[4:]))
+ except Exception as e:
+ raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}": {1}'.format(name, e))
+ if ':' not in name:
+ raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}" (forgot "DNS:" prefix?)'.format(name))
+ raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}" (potentially unsupported by cryptography backend)'.format(name))
+
+
+def _get_hex(bytesstr):
+ if bytesstr is None:
+ return bytesstr
+ data = binascii.hexlify(bytesstr)
+ data = to_text(b':'.join(data[i:i + 2] for i in range(0, len(data), 2)))
+ return data
+
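+# For instance, _get_hex(b'\x04\x08') == '04:08' (colon-separated hex pairs).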
+
+def cryptography_decode_name(name):
+ '''
+    Given a cryptography GeneralName object (x509.DNSName etc.), returns a string.
+ Raises an OpenSSLObjectError if the name is not supported.
+ '''
+ if isinstance(name, x509.DNSName):
+ return 'DNS:{0}'.format(name.value)
+ if isinstance(name, x509.IPAddress):
+ return 'IP:{0}'.format(name.value.compressed)
+ if isinstance(name, x509.RFC822Name):
+ return 'email:{0}'.format(name.value)
+ if isinstance(name, x509.UniformResourceIdentifier):
+ return 'URI:{0}'.format(name.value)
+ if isinstance(name, x509.DirectoryName):
+ # FIXME: test
+ return 'DirName:' + ''.join(['/{0}:{1}'.format(attribute.oid._name, attribute.value) for attribute in name.value])
+ if isinstance(name, x509.RegisteredID):
+ # FIXME: test
+ return 'RegisteredID:{0}'.format(name.value)
+ if isinstance(name, x509.OtherName):
+ # FIXME: test
+ return '{0}:{1}'.format(name.type_id.dotted_string, _get_hex(name.value))
+ raise OpenSSLObjectError('Cannot decode name "{0}"'.format(name))
+
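+# cryptography_get_name() and cryptography_decode_name() round-trip (editor's
+# illustration):
+#
+#     cryptography_decode_name(cryptography_get_name('DNS:example.com'))
+#     # -> 'DNS:example.com'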
+
+def _cryptography_get_keyusage(usage):
+ '''
+ Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage().
+ Raises an OpenSSLObjectError if the identifier is unknown.
+ '''
+ if usage in ('Digital Signature', 'digitalSignature'):
+ return 'digital_signature'
+ if usage in ('Non Repudiation', 'nonRepudiation'):
+ return 'content_commitment'
+ if usage in ('Key Encipherment', 'keyEncipherment'):
+ return 'key_encipherment'
+ if usage in ('Data Encipherment', 'dataEncipherment'):
+ return 'data_encipherment'
+ if usage in ('Key Agreement', 'keyAgreement'):
+ return 'key_agreement'
+ if usage in ('Certificate Sign', 'keyCertSign'):
+ return 'key_cert_sign'
+ if usage in ('CRL Sign', 'cRLSign'):
+ return 'crl_sign'
+ if usage in ('Encipher Only', 'encipherOnly'):
+ return 'encipher_only'
+ if usage in ('Decipher Only', 'decipherOnly'):
+ return 'decipher_only'
+ raise OpenSSLObjectError('Unknown key usage "{0}"'.format(usage))
+
+
+def cryptography_parse_key_usage_params(usages):
+ '''
+ Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage().
+ Raises an OpenSSLObjectError if an identifier is unknown.
+ '''
+ params = dict(
+ digital_signature=False,
+ content_commitment=False,
+ key_encipherment=False,
+ data_encipherment=False,
+ key_agreement=False,
+ key_cert_sign=False,
+ crl_sign=False,
+ encipher_only=False,
+ decipher_only=False,
+ )
+ for usage in usages:
+ params[_cryptography_get_keyusage(usage)] = True
+ return params
+
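+# The resulting dict can be splatted into cryptography's KeyUsage object
+# (editor's sketch):
+#
+#     params = cryptography_parse_key_usage_params(['digitalSignature', 'keyEncipherment'])
+#     key_usage = x509.KeyUsage(**params)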
+
+def cryptography_get_basic_constraints(constraints):
+ '''
+ Given a list of constraints, returns a tuple (ca, path_length).
+ Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed.
+ '''
+ ca = False
+ path_length = None
+ if constraints:
+ for constraint in constraints:
+ if constraint.startswith('CA:'):
+ if constraint == 'CA:TRUE':
+ ca = True
+ elif constraint == 'CA:FALSE':
+ ca = False
+ else:
+ raise OpenSSLObjectError('Unknown basic constraint value "{0}" for CA'.format(constraint[3:]))
+ elif constraint.startswith('pathlen:'):
+ v = constraint[len('pathlen:'):]
+ try:
+ path_length = int(v)
+ except Exception as e:
+ raise OpenSSLObjectError('Cannot parse path length constraint "{0}" ({1})'.format(v, e))
+ else:
+ raise OpenSSLObjectError('Unknown basic constraint "{0}"'.format(constraint))
+ return ca, path_length
+
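+# For example, cryptography_get_basic_constraints(['CA:TRUE', 'pathlen:1'])
+# returns (True, 1); an empty or None input yields (False, None).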
+
+def binary_exp_mod(f, e, m):
+ '''Computes f^e mod m in O(log e) multiplications modulo m.'''
+ # Compute len_e = floor(log_2(e))
+ len_e = -1
+ x = e
+ while x > 0:
+ x >>= 1
+ len_e += 1
+ # Compute f**e mod m
+ result = 1
+ for k in range(len_e, -1, -1):
+ result = (result * result) % m
+ if ((e >> k) & 1) != 0:
+ result = (result * f) % m
+ return result
+
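+# Square-and-multiply: the loop squares once per bit of e and multiplies by f
+# whenever that bit is set, so e.g. binary_exp_mod(3, 10, 7) == pow(3, 10, 7) == 4.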
+
+def simple_gcd(a, b):
+ '''Compute GCD of its two inputs.'''
+ while b != 0:
+ a, b = b, a % b
+ return a
+
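+# This is Euclid's algorithm, e.g. simple_gcd(12, 18) == 6.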
+
+def quick_is_not_prime(n):
+    '''Does some quick checks to see if we can poke a hole in the primality of n.
+
+    A result of `False` does **not** mean that the number is prime; it just means
+    that we couldn't detect quickly whether it is not prime.
+    '''
+    if n < 2:
+        return True
+    # The constant in the next line is the product of all primes < 200
+    gcd = simple_gcd(n, 7799922041683461553249199106329813876687996789903550945093032474868511536164700810)
+    if 1 < gcd < n:
+        # gcd is a non-trivial divisor of n, so n is composite. (When gcd == n,
+        # n divides the prime product and might itself be a prime < 200, so no
+        # conclusion is drawn.)
+        return True
+ # TODO: maybe do some iterations of Miller-Rabin to increase confidence
+ # (https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test)
+ return False
+
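+# For instance, quick_is_not_prime(4) is True (shared factor 2), while
+# quick_is_not_prime(13) and quick_is_not_prime(65537) are both False: 13
+# divides the prime product itself, and the prime 65537 shares no factor
+# with it.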
+
+python_version = (sys.version_info[0], sys.version_info[1])
+if python_version >= (3, 1) or (2, 7) <= python_version < (3, 0):
+    # int.bit_length() exists on Python 2.7 and on Python 3.1+ (Ansible still
+    # supports Python 2.6 on remote nodes)
+ def count_bits(no):
+ no = abs(no)
+ if no == 0:
+ return 0
+ return no.bit_length()
+else:
+ # Slow, but works
+ def count_bits(no):
+ no = abs(no)
+ count = 0
+ while no > 0:
+ no >>= 1
+ count += 1
+ return count
+
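+# Both implementations agree, e.g. count_bits(255) == 8 and count_bits(256) == 9.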
+
+PEM_START = '-----BEGIN '
+PEM_END = '-----'
+PKCS8_PRIVATEKEY_NAMES = ('PRIVATE KEY', 'ENCRYPTED PRIVATE KEY')
+PKCS1_PRIVATEKEY_SUFFIX = ' PRIVATE KEY'
+
+
+def identify_private_key_format(content):
+ '''Given the contents of a private key file, identifies its format.'''
+ # See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
+ # (PEM_read_bio_PrivateKey)
+ # and https://github.com/openssl/openssl/blob/master/include/openssl/pem.h#L46-L47
+ # (PEM_STRING_PKCS8, PEM_STRING_PKCS8INF)
+ try:
+ lines = content.decode('utf-8').splitlines(False)
+ if lines[0].startswith(PEM_START) and lines[0].endswith(PEM_END) and len(lines[0]) > len(PEM_START) + len(PEM_END):
+ name = lines[0][len(PEM_START):-len(PEM_END)]
+ if name in PKCS8_PRIVATEKEY_NAMES:
+ return 'pkcs8'
+ if len(name) > len(PKCS1_PRIVATEKEY_SUFFIX) and name.endswith(PKCS1_PRIVATEKEY_SUFFIX):
+ return 'pkcs1'
+ return 'unknown-pem'
+ except UnicodeDecodeError:
+ pass
+ return 'raw'
+
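+# For example (editor's note): content starting with
+# '-----BEGIN RSA PRIVATE KEY-----' is reported as 'pkcs1',
+# '-----BEGIN ENCRYPTED PRIVATE KEY-----' as 'pkcs8', and content that does
+# not decode as UTF-8 (e.g. DER) as 'raw'.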
+
+def cryptography_key_needs_digest_for_signing(key):
+ '''Tests whether the given private key requires a digest algorithm for signing.
+
+ Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm.
+ '''
+ if CRYPTOGRAPHY_HAS_ED25519 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey):
+ return False
+ if CRYPTOGRAPHY_HAS_ED448 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
+ return False
+ return True
+
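+# This matters e.g. for CertificateSigningRequestBuilder.sign(), which expects
+# algorithm=None for Ed25519/Ed448 keys but a hash algorithm instance (such as
+# hashes.SHA256()) for RSA/DSA/EC keys (editor's note).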
+
+def cryptography_compare_public_keys(key1, key2):
+ '''Tests whether two public keys are the same.
+
+ Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers().
+ '''
+ if CRYPTOGRAPHY_HAS_ED25519:
+ a = isinstance(key1, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey)
+ b = isinstance(key2, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey)
+ if a or b:
+ if not a or not b:
+ return False
+ a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
+ b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
+ return a == b
+ if CRYPTOGRAPHY_HAS_ED448:
+ a = isinstance(key1, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey)
+ b = isinstance(key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey)
+ if a or b:
+ if not a or not b:
+ return False
+ a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
+ b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
+ return a == b
+ return key1.public_numbers() == key2.public_numbers()
+
+
+if HAS_CRYPTOGRAPHY:
+ REVOCATION_REASON_MAP = {
+ 'unspecified': x509.ReasonFlags.unspecified,
+ 'key_compromise': x509.ReasonFlags.key_compromise,
+ 'ca_compromise': x509.ReasonFlags.ca_compromise,
+ 'affiliation_changed': x509.ReasonFlags.affiliation_changed,
+ 'superseded': x509.ReasonFlags.superseded,
+ 'cessation_of_operation': x509.ReasonFlags.cessation_of_operation,
+ 'certificate_hold': x509.ReasonFlags.certificate_hold,
+ 'privilege_withdrawn': x509.ReasonFlags.privilege_withdrawn,
+ 'aa_compromise': x509.ReasonFlags.aa_compromise,
+ 'remove_from_crl': x509.ReasonFlags.remove_from_crl,
+ }
+ REVOCATION_REASON_MAP_INVERSE = dict()
+ for k, v in REVOCATION_REASON_MAP.items():
+ REVOCATION_REASON_MAP_INVERSE[v] = k
+
+
+def cryptography_decode_revoked_certificate(cert):
+ result = {
+ 'serial_number': cert.serial_number,
+ 'revocation_date': cert.revocation_date,
+ 'issuer': None,
+ 'issuer_critical': False,
+ 'reason': None,
+ 'reason_critical': False,
+ 'invalidity_date': None,
+ 'invalidity_date_critical': False,
+ }
+ try:
+ ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
+ result['issuer'] = list(ext.value)
+ result['issuer_critical'] = ext.critical
+ except x509.ExtensionNotFound:
+ pass
+ try:
+ ext = cert.extensions.get_extension_for_class(x509.CRLReason)
+ result['reason'] = ext.value.reason
+ result['reason_critical'] = ext.critical
+ except x509.ExtensionNotFound:
+ pass
+ try:
+ ext = cert.extensions.get_extension_for_class(x509.InvalidityDate)
+ result['invalidity_date'] = ext.value.invalidity_date
+ result['invalidity_date_critical'] = ext.critical
+ except x509.ExtensionNotFound:
+ pass
+ return result
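+
+
+# Example usage (editor's sketch; assumes a PEM-encoded CRL in `pem_data`):
+#
+#     crl = x509.load_pem_x509_crl(pem_data, backend)
+#     revoked = [cryptography_decode_revoked_certificate(c) for c in crl]
+#
+# Each dict carries the serial number, revocation date and any decoded
+# CRL-entry extensions.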
diff --git a/test/support/integration/plugins/module_utils/database.py b/test/support/integration/plugins/module_utils/database.py
new file mode 100644
index 00000000..014939a2
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/database.py
@@ -0,0 +1,142 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+class SQLParseError(Exception):
+ pass
+
+
+class UnclosedQuoteError(SQLParseError):
+ pass
+
+
+# maps a type of identifier to the maximum number of dot levels that are
+# allowed to specify that identifier. For example, a database column can be
+# specified by up to 4 levels: database.schema.table.column
+_PG_IDENTIFIER_TO_DOT_LEVEL = dict(
+ database=1,
+ schema=2,
+ table=3,
+ column=4,
+ role=1,
+ tablespace=1,
+ sequence=3,
+ publication=1,
+)
+_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1, vars=1)
+
+
+def _find_end_quote(identifier, quote_char):
+ accumulate = 0
+ while True:
+ try:
+ quote = identifier.index(quote_char)
+ except ValueError:
+ raise UnclosedQuoteError
+ accumulate = accumulate + quote
+ try:
+ next_char = identifier[quote + 1]
+ except IndexError:
+ return accumulate
+ if next_char == quote_char:
+ try:
+ identifier = identifier[quote + 2:]
+ accumulate = accumulate + 2
+ except IndexError:
+ raise UnclosedQuoteError
+ else:
+ return accumulate
+
+
+def _identifier_parse(identifier, quote_char):
+ if not identifier:
+ raise SQLParseError('Identifier name unspecified or unquoted trailing dot')
+
+ already_quoted = False
+ if identifier.startswith(quote_char):
+ already_quoted = True
+ try:
+ end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
+ except UnclosedQuoteError:
+ already_quoted = False
+ else:
+ if end_quote < len(identifier) - 1:
+ if identifier[end_quote + 1] == '.':
+ dot = end_quote + 1
+ first_identifier = identifier[:dot]
+ next_identifier = identifier[dot + 1:]
+ further_identifiers = _identifier_parse(next_identifier, quote_char)
+ further_identifiers.insert(0, first_identifier)
+ else:
+ raise SQLParseError('User escaped identifiers must escape extra quotes')
+ else:
+ further_identifiers = [identifier]
+
+ if not already_quoted:
+ try:
+ dot = identifier.index('.')
+ except ValueError:
+ identifier = identifier.replace(quote_char, quote_char * 2)
+ identifier = ''.join((quote_char, identifier, quote_char))
+ further_identifiers = [identifier]
+ else:
+ if dot == 0 or dot >= len(identifier) - 1:
+ identifier = identifier.replace(quote_char, quote_char * 2)
+ identifier = ''.join((quote_char, identifier, quote_char))
+ further_identifiers = [identifier]
+ else:
+ first_identifier = identifier[:dot]
+ next_identifier = identifier[dot + 1:]
+ further_identifiers = _identifier_parse(next_identifier, quote_char)
+ first_identifier = first_identifier.replace(quote_char, quote_char * 2)
+ first_identifier = ''.join((quote_char, first_identifier, quote_char))
+ further_identifiers.insert(0, first_identifier)
+
+ return further_identifiers
+
+
+def pg_quote_identifier(identifier, id_type):
+ identifier_fragments = _identifier_parse(identifier, quote_char='"')
+ if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]:
+ raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]))
+ return '.'.join(identifier_fragments)
+
+
+def mysql_quote_identifier(identifier, id_type):
+ identifier_fragments = _identifier_parse(identifier, quote_char='`')
+ if (len(identifier_fragments) - 1) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]:
+ raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]))
+
+ special_cased_fragments = []
+ for fragment in identifier_fragments:
+ if fragment == '`*`':
+ special_cased_fragments.append('*')
+ else:
+ special_cased_fragments.append(fragment)
+
+ return '.'.join(special_cased_fragments)
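+
+
+# Illustrative behaviour (not part of the original module):
+#
+#     pg_quote_identifier('my_schema.my_table', 'table')  ->  '"my_schema"."my_table"'
+#     pg_quote_identifier('bad"name', 'role')             ->  '"bad""name"'
+#     mysql_quote_identifier('mydb.*', 'table')           ->  '`mydb`.*'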
diff --git a/test/support/integration/plugins/module_utils/docker/__init__.py b/test/support/integration/plugins/module_utils/docker/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/docker/__init__.py
diff --git a/test/support/integration/plugins/module_utils/docker/common.py b/test/support/integration/plugins/module_utils/docker/common.py
new file mode 100644
index 00000000..03307250
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/docker/common.py
@@ -0,0 +1,1022 @@
+#
+# Copyright 2016 Red Hat | Ansible
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+
+import os
+import platform
+import re
+import sys
+from datetime import timedelta
+from distutils.version import LooseVersion
+
+
+from ansible.module_utils.basic import AnsibleModule, env_fallback, missing_required_lib
+from ansible.module_utils.common._collections_compat import Mapping, Sequence
+from ansible.module_utils.six import string_types
+from ansible.module_utils.six.moves.urllib.parse import urlparse
+from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE, BOOLEANS_FALSE
+
+HAS_DOCKER_PY = True
+HAS_DOCKER_PY_2 = False
+HAS_DOCKER_PY_3 = False
+HAS_DOCKER_ERROR = None
+
+try:
+ from requests.exceptions import SSLError
+ from docker import __version__ as docker_version
+ from docker.errors import APIError, NotFound, TLSParameterError
+ from docker.tls import TLSConfig
+ from docker import auth
+
+ if LooseVersion(docker_version) >= LooseVersion('3.0.0'):
+ HAS_DOCKER_PY_3 = True
+ from docker import APIClient as Client
+ elif LooseVersion(docker_version) >= LooseVersion('2.0.0'):
+ HAS_DOCKER_PY_2 = True
+ from docker import APIClient as Client
+ else:
+ from docker import Client
+
+except ImportError as exc:
+ HAS_DOCKER_ERROR = str(exc)
+ HAS_DOCKER_PY = False
+
+
+# The next 2 imports ``docker.models`` and ``docker.ssladapter`` are used
+# to ensure the user does not have both ``docker`` and ``docker-py`` modules
+# installed, as they utilize the same namespace and are incompatible
+try:
+ # docker (Docker SDK for Python >= 2.0.0)
+ import docker.models # noqa: F401
+ HAS_DOCKER_MODELS = True
+except ImportError:
+ HAS_DOCKER_MODELS = False
+
+try:
+ # docker-py (Docker SDK for Python < 2.0.0)
+ import docker.ssladapter # noqa: F401
+ HAS_DOCKER_SSLADAPTER = True
+except ImportError:
+ HAS_DOCKER_SSLADAPTER = False
+
+
+try:
+ from requests.exceptions import RequestException
+except ImportError:
+    # Either docker-py is no longer using requests, docker-py isn't installed at all,
+    # or docker-py's dependency requests is missing. In any case, define an exception
+    # class RequestException so that our code doesn't break.
+ class RequestException(Exception):
+ pass
+
+
+DEFAULT_DOCKER_HOST = 'unix://var/run/docker.sock'
+DEFAULT_TLS = False
+DEFAULT_TLS_VERIFY = False
+DEFAULT_TLS_HOSTNAME = 'localhost'
+MIN_DOCKER_VERSION = "1.8.0"
+DEFAULT_TIMEOUT_SECONDS = 60
+
+DOCKER_COMMON_ARGS = dict(
+ docker_host=dict(type='str', default=DEFAULT_DOCKER_HOST, fallback=(env_fallback, ['DOCKER_HOST']), aliases=['docker_url']),
+ tls_hostname=dict(type='str', default=DEFAULT_TLS_HOSTNAME, fallback=(env_fallback, ['DOCKER_TLS_HOSTNAME'])),
+ api_version=dict(type='str', default='auto', fallback=(env_fallback, ['DOCKER_API_VERSION']), aliases=['docker_api_version']),
+ timeout=dict(type='int', default=DEFAULT_TIMEOUT_SECONDS, fallback=(env_fallback, ['DOCKER_TIMEOUT'])),
+ ca_cert=dict(type='path', aliases=['tls_ca_cert', 'cacert_path']),
+ client_cert=dict(type='path', aliases=['tls_client_cert', 'cert_path']),
+ client_key=dict(type='path', aliases=['tls_client_key', 'key_path']),
+ ssl_version=dict(type='str', fallback=(env_fallback, ['DOCKER_SSL_VERSION'])),
+ tls=dict(type='bool', default=DEFAULT_TLS, fallback=(env_fallback, ['DOCKER_TLS'])),
+ validate_certs=dict(type='bool', default=DEFAULT_TLS_VERIFY, fallback=(env_fallback, ['DOCKER_TLS_VERIFY']), aliases=['tls_verify']),
+ debug=dict(type='bool', default=False)
+)
+
+DOCKER_MUTUALLY_EXCLUSIVE = []
+
+DOCKER_REQUIRED_TOGETHER = [
+ ['client_cert', 'client_key']
+]
+
+DEFAULT_DOCKER_REGISTRY = 'https://index.docker.io/v1/'
+EMAIL_REGEX = r'[^@]+@[^@]+\.[^@]+'
+BYTE_SUFFIXES = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
+
+
+if not HAS_DOCKER_PY:
+ docker_version = None
+
+    # No Docker SDK for Python. Create a placeholder client to allow
+    # instantiation of AnsibleModule and proper error handling
+ class Client(object): # noqa: F811
+ def __init__(self, **kwargs):
+ pass
+
+ class APIError(Exception): # noqa: F811
+ pass
+
+ class NotFound(Exception): # noqa: F811
+ pass
+
+
+def is_image_name_id(name):
+ """Check whether the given image name is in fact an image ID (hash)."""
+ if re.match('^sha256:[0-9a-fA-F]{64}$', name):
+ return True
+ return False
+
+
+def is_valid_tag(tag, allow_empty=False):
+ """Check whether the given string is a valid docker tag name."""
+ if not tag:
+ return allow_empty
+ # See here ("Extended description") for a definition what tags can be:
+ # https://docs.docker.com/engine/reference/commandline/tag/
+ return bool(re.match('^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$', tag))
+
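+# Illustrative checks (not part of the original module):
+#
+#     is_image_name_id('sha256:' + 64 * 'a')  ->  True
+#     is_valid_tag('v1.0.2')                  ->  True
+#     is_valid_tag('-broken')                 ->  False  (first character must be alphanumeric or '_')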
+
+def sanitize_result(data):
+ """Sanitize data object for return to Ansible.
+
+ When the data object contains types such as docker.types.containers.HostConfig,
+ Ansible will fail when these are returned via exit_json or fail_json.
+ HostConfig is derived from dict, but its constructor requires additional
+ arguments. This function sanitizes data structures by recursively converting
+ everything derived from dict to dict and everything derived from list (and tuple)
+ to a list.
+ """
+ if isinstance(data, dict):
+ return dict((k, sanitize_result(v)) for k, v in data.items())
+ elif isinstance(data, (list, tuple)):
+ return [sanitize_result(v) for v in data]
+ else:
+ return data
+
+
+class DockerBaseClass(object):
+
+ def __init__(self):
+ self.debug = False
+
+ def log(self, msg, pretty_print=False):
+ pass
+ # if self.debug:
+ # log_file = open('docker.log', 'a')
+ # if pretty_print:
+ # log_file.write(json.dumps(msg, sort_keys=True, indent=4, separators=(',', ': ')))
+ # log_file.write(u'\n')
+ # else:
+ # log_file.write(msg + u'\n')
+
+
+def update_tls_hostname(result):
+ if result['tls_hostname'] is None:
+ # get default machine name from the url
+ parsed_url = urlparse(result['docker_host'])
+ if ':' in parsed_url.netloc:
+ result['tls_hostname'] = parsed_url.netloc[:parsed_url.netloc.rindex(':')]
+ else:
+            result['tls_hostname'] = parsed_url.netloc
+
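+# Illustrative behaviour (not part of the original module): for a result dict
+# with docker_host 'tcp://192.0.2.10:2376' and tls_hostname None,
+# update_tls_hostname() sets tls_hostname to '192.0.2.10', the netloc part
+# before the last ':'.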
+
+def _get_tls_config(fail_function, **kwargs):
+ try:
+ tls_config = TLSConfig(**kwargs)
+ return tls_config
+ except TLSParameterError as exc:
+ fail_function("TLS config error: %s" % exc)
+
+
+def get_connect_params(auth, fail_function):
+ if auth['tls'] or auth['tls_verify']:
+ auth['docker_host'] = auth['docker_host'].replace('tcp://', 'https://')
+
+ if auth['tls_verify'] and auth['cert_path'] and auth['key_path']:
+ # TLS with certs and host verification
+ if auth['cacert_path']:
+ tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']),
+ ca_cert=auth['cacert_path'],
+ verify=True,
+ assert_hostname=auth['tls_hostname'],
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ else:
+ tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']),
+ verify=True,
+ assert_hostname=auth['tls_hostname'],
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls_verify'] and auth['cacert_path']:
+ # TLS with cacert only
+ tls_config = _get_tls_config(ca_cert=auth['cacert_path'],
+ assert_hostname=auth['tls_hostname'],
+ verify=True,
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls_verify']:
+ # TLS with verify and no certs
+ tls_config = _get_tls_config(verify=True,
+ assert_hostname=auth['tls_hostname'],
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls'] and auth['cert_path'] and auth['key_path']:
+ # TLS with certs and no host verification
+ tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']),
+ verify=False,
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ if auth['tls']:
+        # TLS with no certs and no host verification
+ tls_config = _get_tls_config(verify=False,
+ ssl_version=auth['ssl_version'],
+ fail_function=fail_function)
+ return dict(base_url=auth['docker_host'],
+ tls=tls_config,
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+ # No TLS
+ return dict(base_url=auth['docker_host'],
+ version=auth['api_version'],
+ timeout=auth['timeout'])
+
+
+DOCKERPYUPGRADE_SWITCH_TO_DOCKER = "Try `pip uninstall docker-py` followed by `pip install docker`."
+DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade."
+DOCKERPYUPGRADE_RECOMMEND_DOCKER = ("Use `pip install --upgrade docker-py` to upgrade. "
+ "Hint: if you do not need Python 2.6 support, try "
+ "`pip uninstall docker-py` instead, followed by `pip install docker`.")
+
+
+class AnsibleDockerClient(Client):
+
+ def __init__(self, argument_spec=None, supports_check_mode=False, mutually_exclusive=None,
+ required_together=None, required_if=None, min_docker_version=MIN_DOCKER_VERSION,
+ min_docker_api_version=None, option_minimal_versions=None,
+ option_minimal_versions_ignore_params=None, fail_results=None):
+
+ # Modules can put information in here which will always be returned
+ # in case client.fail() is called.
+ self.fail_results = fail_results or {}
+
+ merged_arg_spec = dict()
+ merged_arg_spec.update(DOCKER_COMMON_ARGS)
+ if argument_spec:
+ merged_arg_spec.update(argument_spec)
+ self.arg_spec = merged_arg_spec
+
+ mutually_exclusive_params = []
+ mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
+ if mutually_exclusive:
+ mutually_exclusive_params += mutually_exclusive
+
+ required_together_params = []
+ required_together_params += DOCKER_REQUIRED_TOGETHER
+ if required_together:
+ required_together_params += required_together
+
+ self.module = AnsibleModule(
+ argument_spec=merged_arg_spec,
+ supports_check_mode=supports_check_mode,
+ mutually_exclusive=mutually_exclusive_params,
+ required_together=required_together_params,
+ required_if=required_if)
+
+ NEEDS_DOCKER_PY2 = (LooseVersion(min_docker_version) >= LooseVersion('2.0.0'))
+
+ self.docker_py_version = LooseVersion(docker_version)
+
+ if HAS_DOCKER_MODELS and HAS_DOCKER_SSLADAPTER:
+ self.fail("Cannot have both the docker-py and docker python modules (old and new version of Docker "
+ "SDK for Python) installed together as they use the same namespace and cause a corrupt "
+ "installation. Please uninstall both packages, and re-install only the docker-py or docker "
+ "python module (for %s's Python %s). It is recommended to install the docker module if no "
+ "support for Python 2.6 is required. Please note that simply uninstalling one of the modules "
+ "can leave the other module in a broken state." % (platform.node(), sys.executable))
+
+ if not HAS_DOCKER_PY:
+ if NEEDS_DOCKER_PY2:
+ msg = missing_required_lib("Docker SDK for Python: docker")
+ msg = msg + ", for example via `pip install docker`. The error was: %s"
+ else:
+ msg = missing_required_lib("Docker SDK for Python: docker (Python >= 2.7) or docker-py (Python 2.6)")
+ msg = msg + ", for example via `pip install docker` or `pip install docker-py` (Python 2.6). The error was: %s"
+ self.fail(msg % HAS_DOCKER_ERROR)
+
+ if self.docker_py_version < LooseVersion(min_docker_version):
+ msg = "Error: Docker SDK for Python version is %s (%s's Python %s). Minimum version required is %s."
+ if not NEEDS_DOCKER_PY2:
+ # The minimal required version is < 2.0 (and the current version as well).
+ # Advertise docker (instead of docker-py) for non-Python-2.6 users.
+ msg += DOCKERPYUPGRADE_RECOMMEND_DOCKER
+ elif docker_version < LooseVersion('2.0'):
+ msg += DOCKERPYUPGRADE_SWITCH_TO_DOCKER
+ else:
+ msg += DOCKERPYUPGRADE_UPGRADE_DOCKER
+ self.fail(msg % (docker_version, platform.node(), sys.executable, min_docker_version))
+
+ self.debug = self.module.params.get('debug')
+ self.check_mode = self.module.check_mode
+ self._connect_params = get_connect_params(self.auth_params, fail_function=self.fail)
+
+ try:
+ super(AnsibleDockerClient, self).__init__(**self._connect_params)
+ self.docker_api_version_str = self.version()['ApiVersion']
+ except APIError as exc:
+ self.fail("Docker API error: %s" % exc)
+ except Exception as exc:
+ self.fail("Error connecting: %s" % exc)
+
+ self.docker_api_version = LooseVersion(self.docker_api_version_str)
+ if min_docker_api_version is not None:
+ if self.docker_api_version < LooseVersion(min_docker_api_version):
+ self.fail('Docker API version is %s. Minimum version required is %s.' % (self.docker_api_version_str, min_docker_api_version))
+
+ if option_minimal_versions is not None:
+ self._get_minimal_versions(option_minimal_versions, option_minimal_versions_ignore_params)
+
+ def log(self, msg, pretty_print=False):
+ pass
+ # if self.debug:
+ # log_file = open('docker.log', 'a')
+ # if pretty_print:
+ # log_file.write(json.dumps(msg, sort_keys=True, indent=4, separators=(',', ': ')))
+ # log_file.write(u'\n')
+ # else:
+ # log_file.write(msg + u'\n')
+
+ def fail(self, msg, **kwargs):
+ self.fail_results.update(kwargs)
+ self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
+
+ @staticmethod
+ def _get_value(param_name, param_value, env_variable, default_value):
+ if param_value is not None:
+ # take module parameter value
+ if param_value in BOOLEANS_TRUE:
+ return True
+ if param_value in BOOLEANS_FALSE:
+ return False
+ return param_value
+
+ if env_variable is not None:
+ env_value = os.environ.get(env_variable)
+ if env_value is not None:
+ # take the env variable value
+ if param_name == 'cert_path':
+ return os.path.join(env_value, 'cert.pem')
+ if param_name == 'cacert_path':
+ return os.path.join(env_value, 'ca.pem')
+ if param_name == 'key_path':
+ return os.path.join(env_value, 'key.pem')
+ if env_value in BOOLEANS_TRUE:
+ return True
+ if env_value in BOOLEANS_FALSE:
+ return False
+ return env_value
+
+ # take the default
+ return default_value
+
+ @property
+ def auth_params(self):
+ # Get authentication credentials.
+ # Precedence: module parameters-> environment variables-> defaults.
+
+ self.log('Getting credentials')
+
+ params = dict()
+ for key in DOCKER_COMMON_ARGS:
+ params[key] = self.module.params.get(key)
+
+ if self.module.params.get('use_tls'):
+ # support use_tls option in docker_image.py. This will be deprecated.
+ use_tls = self.module.params.get('use_tls')
+ if use_tls == 'encrypt':
+ params['tls'] = True
+ if use_tls == 'verify':
+ params['validate_certs'] = True
+
+ result = dict(
+ docker_host=self._get_value('docker_host', params['docker_host'], 'DOCKER_HOST',
+ DEFAULT_DOCKER_HOST),
+ tls_hostname=self._get_value('tls_hostname', params['tls_hostname'],
+ 'DOCKER_TLS_HOSTNAME', DEFAULT_TLS_HOSTNAME),
+ api_version=self._get_value('api_version', params['api_version'], 'DOCKER_API_VERSION',
+ 'auto'),
+ cacert_path=self._get_value('cacert_path', params['ca_cert'], 'DOCKER_CERT_PATH', None),
+ cert_path=self._get_value('cert_path', params['client_cert'], 'DOCKER_CERT_PATH', None),
+ key_path=self._get_value('key_path', params['client_key'], 'DOCKER_CERT_PATH', None),
+ ssl_version=self._get_value('ssl_version', params['ssl_version'], 'DOCKER_SSL_VERSION', None),
+ tls=self._get_value('tls', params['tls'], 'DOCKER_TLS', DEFAULT_TLS),
+            tls_verify=self._get_value('tls_verify', params['validate_certs'], 'DOCKER_TLS_VERIFY',
+ DEFAULT_TLS_VERIFY),
+ timeout=self._get_value('timeout', params['timeout'], 'DOCKER_TIMEOUT',
+ DEFAULT_TIMEOUT_SECONDS),
+ )
+
+ update_tls_hostname(result)
+
+ return result
+
+ def _handle_ssl_error(self, error):
+ match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
+ if match:
+ self.fail("You asked for verification that Docker daemons certificate's hostname matches %s. "
+ "The actual certificate's hostname is %s. Most likely you need to set DOCKER_TLS_HOSTNAME "
+ "or pass `tls_hostname` with a value of %s. You may also use TLS without verification by "
+ "setting the `tls` parameter to true."
+ % (self.auth_params['tls_hostname'], match.group(1), match.group(1)))
+ self.fail("SSL Exception: %s" % (error))
+
+ def _get_minimal_versions(self, option_minimal_versions, ignore_params=None):
+ self.option_minimal_versions = dict()
+ for option in self.module.argument_spec:
+ if ignore_params is not None:
+ if option in ignore_params:
+ continue
+ self.option_minimal_versions[option] = dict()
+ self.option_minimal_versions.update(option_minimal_versions)
+
+ for option, data in self.option_minimal_versions.items():
+ # Test whether option is supported, and store result
+ support_docker_py = True
+ support_docker_api = True
+ if 'docker_py_version' in data:
+ support_docker_py = self.docker_py_version >= LooseVersion(data['docker_py_version'])
+ if 'docker_api_version' in data:
+ support_docker_api = self.docker_api_version >= LooseVersion(data['docker_api_version'])
+ data['supported'] = support_docker_py and support_docker_api
+ # Fail if option is not supported but used
+ if not data['supported']:
+ # Test whether option is specified
+ if 'detect_usage' in data:
+ used = data['detect_usage'](self)
+ else:
+ used = self.module.params.get(option) is not None
+ if used and 'default' in self.module.argument_spec[option]:
+ used = self.module.params[option] != self.module.argument_spec[option]['default']
+ if used:
+ # If the option is used, compose error message.
+ if 'usage_msg' in data:
+ usg = data['usage_msg']
+ else:
+ usg = 'set %s option' % (option, )
+ if not support_docker_api:
+ msg = 'Docker API version is %s. Minimum version required is %s to %s.'
+ msg = msg % (self.docker_api_version_str, data['docker_api_version'], usg)
+ elif not support_docker_py:
+ msg = "Docker SDK for Python version is %s (%s's Python %s). Minimum version required is %s to %s. "
+ if LooseVersion(data['docker_py_version']) < LooseVersion('2.0.0'):
+ msg += DOCKERPYUPGRADE_RECOMMEND_DOCKER
+ elif self.docker_py_version < LooseVersion('2.0.0'):
+ msg += DOCKERPYUPGRADE_SWITCH_TO_DOCKER
+ else:
+ msg += DOCKERPYUPGRADE_UPGRADE_DOCKER
+ msg = msg % (docker_version, platform.node(), sys.executable, data['docker_py_version'], usg)
+ else:
+ # should not happen
+ msg = 'Cannot %s with your configuration.' % (usg, )
+ self.fail(msg)
+
+ def get_container_by_id(self, container_id):
+ try:
+ self.log("Inspecting container Id %s" % container_id)
+ result = self.inspect_container(container=container_id)
+ self.log("Completed container inspection")
+ return result
+ except NotFound as dummy:
+ return None
+ except Exception as exc:
+ self.fail("Error inspecting container: %s" % exc)
+
+ def get_container(self, name=None):
+ '''
+ Lookup a container and return the inspection results.
+ '''
+ if name is None:
+ return None
+
+ search_name = name
+ if not name.startswith('/'):
+ search_name = '/' + name
+
+ result = None
+ try:
+ for container in self.containers(all=True):
+ self.log("testing container: %s" % (container['Names']))
+ if isinstance(container['Names'], list) and search_name in container['Names']:
+ result = container
+ break
+ if container['Id'].startswith(name):
+ result = container
+ break
+ if container['Id'] == name:
+ result = container
+ break
+ except SSLError as exc:
+ self._handle_ssl_error(exc)
+ except Exception as exc:
+ self.fail("Error retrieving container list: %s" % exc)
+
+ if result is None:
+ return None
+
+ return self.get_container_by_id(result['Id'])
+
+ def get_network(self, name=None, network_id=None):
+ '''
+ Lookup a network and return the inspection results.
+ '''
+ if name is None and network_id is None:
+ return None
+
+ result = None
+
+ if network_id is None:
+ try:
+ for network in self.networks():
+ self.log("testing network: %s" % (network['Name']))
+ if name == network['Name']:
+ result = network
+ break
+ if network['Id'].startswith(name):
+ result = network
+ break
+ except SSLError as exc:
+ self._handle_ssl_error(exc)
+ except Exception as exc:
+ self.fail("Error retrieving network list: %s" % exc)
+
+ if result is not None:
+ network_id = result['Id']
+
+ if network_id is not None:
+ try:
+ self.log("Inspecting network Id %s" % network_id)
+ result = self.inspect_network(network_id)
+ self.log("Completed network inspection")
+ except NotFound as dummy:
+ return None
+ except Exception as exc:
+ self.fail("Error inspecting network: %s" % exc)
+
+ return result
+
+ def find_image(self, name, tag):
+ '''
+ Lookup an image (by name and tag) and return the inspection results.
+ '''
+ if not name:
+ return None
+
+ self.log("Find image %s:%s" % (name, tag))
+ images = self._image_lookup(name, tag)
+ if not images:
+            # In API <= 1.20, images pulled from Docker Hub appear under the name 'docker.io/<name>'
+ registry, repo_name = auth.resolve_repository_name(name)
+ if registry == 'docker.io':
+ # If docker.io is explicitly there in name, the image
+ # isn't found in some cases (#41509)
+ self.log("Check for docker.io image: %s" % repo_name)
+ images = self._image_lookup(repo_name, tag)
+ if not images and repo_name.startswith('library/'):
+ # Sometimes library/xxx images are not found
+ lookup = repo_name[len('library/'):]
+ self.log("Check for docker.io image: %s" % lookup)
+ images = self._image_lookup(lookup, tag)
+ if not images:
+ # Last case: if docker.io wasn't there, it can be that
+ # the image wasn't found either (#15586)
+ lookup = "%s/%s" % (registry, repo_name)
+ self.log("Check for docker.io image: %s" % lookup)
+ images = self._image_lookup(lookup, tag)
+
+ if len(images) > 1:
+ self.fail("Registry returned more than one result for %s:%s" % (name, tag))
+
+ if len(images) == 1:
+ try:
+ inspection = self.inspect_image(images[0]['Id'])
+ except Exception as exc:
+ self.fail("Error inspecting image %s:%s - %s" % (name, tag, str(exc)))
+ return inspection
+
+ self.log("Image %s:%s not found." % (name, tag))
+ return None
+
+ def find_image_by_id(self, image_id):
+ '''
+ Lookup an image (by ID) and return the inspection results.
+ '''
+ if not image_id:
+ return None
+
+ self.log("Find image %s (by ID)" % image_id)
+ try:
+ inspection = self.inspect_image(image_id)
+ except Exception as exc:
+ self.fail("Error inspecting image ID %s - %s" % (image_id, str(exc)))
+ return inspection
+
+ def _image_lookup(self, name, tag):
+ '''
+ Including a tag in the name parameter sent to the Docker SDK for Python images method
+ does not work consistently. Instead, get the result set for name and manually check
+ if the tag exists.
+ '''
+ try:
+ response = self.images(name=name)
+ except Exception as exc:
+ self.fail("Error searching for image %s - %s" % (name, str(exc)))
+ images = response
+ if tag:
+ lookup = "%s:%s" % (name, tag)
+ lookup_digest = "%s@%s" % (name, tag)
+ images = []
+ for image in response:
+ tags = image.get('RepoTags')
+ digests = image.get('RepoDigests')
+ if (tags and lookup in tags) or (digests and lookup_digest in digests):
+ images = [image]
+ break
+ return images
+
+ def pull_image(self, name, tag="latest"):
+ '''
+ Pull an image
+ '''
+ self.log("Pulling image %s:%s" % (name, tag))
+ old_tag = self.find_image(name, tag)
+ try:
+ for line in self.pull(name, tag=tag, stream=True, decode=True):
+ self.log(line, pretty_print=True)
+ if line.get('error'):
+ if line.get('errorDetail'):
+ error_detail = line.get('errorDetail')
+ self.fail("Error pulling %s - code: %s message: %s" % (name,
+ error_detail.get('code'),
+ error_detail.get('message')))
+ else:
+ self.fail("Error pulling %s - %s" % (name, line.get('error')))
+ except Exception as exc:
+ self.fail("Error pulling image %s:%s - %s" % (name, tag, str(exc)))
+
+ new_tag = self.find_image(name, tag)
+
+ return new_tag, old_tag == new_tag
+
+ def report_warnings(self, result, warnings_key=None):
+ '''
+ Checks result of client operation for warnings, and if present, outputs them.
+
+ warnings_key should be a list of keys used to crawl the result dictionary.
+ For example, if warnings_key == ['a', 'b'], the function will consider
+ result['a']['b'] if these keys exist. If the result is a non-empty string, it
+ will be reported as a warning. If the result is a list, every entry will be
+ reported as a warning.
+
+ In most cases (if warnings are returned at all), warnings_key should be
+ ['Warnings'] or ['Warning']. The default value (if not specified) is ['Warnings'].
+ '''
+ if warnings_key is None:
+ warnings_key = ['Warnings']
+ for key in warnings_key:
+ if not isinstance(result, Mapping):
+ return
+ result = result.get(key)
+ if isinstance(result, Sequence):
+ for warning in result:
+ self.module.warn('Docker warning: {0}'.format(warning))
+ elif isinstance(result, string_types) and result:
+ self.module.warn('Docker warning: {0}'.format(result))
+
+ def inspect_distribution(self, image, **kwargs):
+ '''
+ Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
+ since prior versions did not support accessing private repositories.
+ '''
+ if self.docker_py_version < LooseVersion('4.0.0'):
+ registry = auth.resolve_repository_name(image)[0]
+ header = auth.get_config_header(self, registry)
+ if header:
+ return self._result(self._get(
+ self._url('/distribution/{0}/json', image),
+ headers={'X-Registry-Auth': header}
+ ), json=True)
+ return super(AnsibleDockerClient, self).inspect_distribution(image, **kwargs)
+
+
+def compare_dict_allow_more_present(av, bv):
+ '''
+ Compare two dictionaries for whether every entry of the first is in the second.
+ '''
+ for key, value in av.items():
+ if key not in bv:
+ return False
+ if bv[key] != value:
+ return False
+ return True
+
+
+def compare_generic(a, b, method, datatype):
+ '''
+ Compare values a and b as described by method and datatype.
+
+ Returns ``True`` if the values compare equal, and ``False`` if not.
+
+ ``a`` is usually the module's parameter, while ``b`` is a property
+ of the current object. ``a`` must not be ``None`` (except for
+ ``datatype == 'value'``).
+
+ Valid values for ``method`` are:
+ - ``ignore`` (always compare as equal);
+ - ``strict`` (only compare if really equal)
+ - ``allow_more_present`` (allow b to have elements which a does not have).
+
+ Valid values for ``datatype`` are:
+ - ``value``: for simple values (strings, numbers, ...);
+ - ``list``: for ``list``s or ``tuple``s where order matters;
+ - ``set``: for ``list``s, ``tuple``s or ``set``s where order does not
+ matter;
+      - ``set(dict)``: for ``list``s, ``tuple``s or ``set``s where order does
+ not matter and which contain ``dict``s; ``allow_more_present`` is used
+ for the ``dict``s, and these are assumed to be dictionaries of values;
+ - ``dict``: for dictionaries of values.
+ '''
+ if method == 'ignore':
+ return True
+ # If a or b is None:
+ if a is None or b is None:
+ # If both are None: equality
+ if a == b:
+ return True
+ # Otherwise, not equal for values, and equal
+ # if the other is empty for set/list/dict
+ if datatype == 'value':
+ return False
+ # For allow_more_present, allow a to be None
+ if method == 'allow_more_present' and a is None:
+ return True
+ # Otherwise, the iterable object which is not None must have length 0
+ return len(b if a is None else a) == 0
+ # Do proper comparison (both objects not None)
+ if datatype == 'value':
+ return a == b
+ elif datatype == 'list':
+ if method == 'strict':
+ return a == b
+ else:
+ i = 0
+ for v in a:
+ while i < len(b) and b[i] != v:
+ i += 1
+ if i == len(b):
+ return False
+ i += 1
+ return True
+ elif datatype == 'dict':
+ if method == 'strict':
+ return a == b
+ else:
+ return compare_dict_allow_more_present(a, b)
+ elif datatype == 'set':
+ set_a = set(a)
+ set_b = set(b)
+ if method == 'strict':
+ return set_a == set_b
+ else:
+ return set_b >= set_a
+ elif datatype == 'set(dict)':
+ for av in a:
+ found = False
+ for bv in b:
+ if compare_dict_allow_more_present(av, bv):
+ found = True
+ break
+ if not found:
+ return False
+ if method == 'strict':
+ # If we would know that both a and b do not contain duplicates,
+ # we could simply compare len(a) to len(b) to finish this test.
+ # We can assume that b has no duplicates (as it is returned by
+ # docker), but we don't know for a.
+ for bv in b:
+ found = False
+ for av in a:
+ if compare_dict_allow_more_present(av, bv):
+ found = True
+ break
+ if not found:
+ return False
+ return True
+
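+# Illustrative behaviour (not part of the original module):
+#
+#     compare_generic([1, 2], [1, 2, 3], 'allow_more_present', 'set')   ->  True
+#     compare_generic([1, 2], [1, 2, 3], 'strict', 'list')              ->  False
+#     compare_generic({'a': 1}, {'a': 1, 'b': 2}, 'allow_more_present', 'dict')  ->  True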
+
+class DifferenceTracker(object):
+ def __init__(self):
+ self._diff = []
+
+ def add(self, name, parameter=None, active=None):
+ self._diff.append(dict(
+ name=name,
+ parameter=parameter,
+ active=active,
+ ))
+
+ def merge(self, other_tracker):
+ self._diff.extend(other_tracker._diff)
+
+ @property
+ def empty(self):
+ return len(self._diff) == 0
+
+ def get_before_after(self):
+ '''
+ Return texts ``before`` and ``after``.
+ '''
+ before = dict()
+ after = dict()
+ for item in self._diff:
+ before[item['name']] = item['active']
+ after[item['name']] = item['parameter']
+ return before, after
+
+ def has_difference_for(self, name):
+ '''
+ Returns a boolean if a difference exists for name
+ '''
+ return any(diff for diff in self._diff if diff['name'] == name)
+
+ def get_legacy_docker_container_diffs(self):
+ '''
+ Return differences in the docker_container legacy format.
+ '''
+ result = []
+ for entry in self._diff:
+ item = dict()
+ item[entry['name']] = dict(
+ parameter=entry['parameter'],
+ container=entry['active'],
+ )
+ result.append(item)
+ return result
+
+ def get_legacy_docker_diffs(self):
+ '''
+ Return differences in the docker_container legacy format.
+ '''
+ result = [entry['name'] for entry in self._diff]
+ return result
+
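+# Illustrative usage (not part of the original module):
+#
+#     tracker = DifferenceTracker()
+#     tracker.add('image', parameter='nginx:1.19', active='nginx:1.18')
+#     tracker.empty               ->  False
+#     tracker.get_before_after()  ->  ({'image': 'nginx:1.18'}, {'image': 'nginx:1.19'})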
+
+def clean_dict_booleans_for_docker_api(data):
+ '''
+ Go doesn't like Python booleans 'True' or 'False', while Ansible is just
+ fine with them in YAML. As such, they need to be converted in cases where
+ we pass dictionaries to the Docker API (e.g. docker_network's
+ driver_options and docker_prune's filters).
+ '''
+ result = dict()
+ if data is not None:
+ for k, v in data.items():
+ if v is True:
+ v = 'true'
+ elif v is False:
+ v = 'false'
+ else:
+ v = str(v)
+ result[str(k)] = v
+ return result
+
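+# Illustrative behaviour (not part of the original module):
+#
+#     clean_dict_booleans_for_docker_api({'foo': True, 'bar': 7})  ->  {'foo': 'true', 'bar': '7'}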
+
+def convert_duration_to_nanosecond(time_str):
+ """
+    Return the time duration in nanoseconds.
+ """
+ if not isinstance(time_str, str):
+ raise ValueError('Missing unit in duration - %s' % time_str)
+
+ regex = re.compile(
+ r'^(((?P<hours>\d+)h)?'
+ r'((?P<minutes>\d+)m(?!s))?'
+ r'((?P<seconds>\d+)s)?'
+ r'((?P<milliseconds>\d+)ms)?'
+ r'((?P<microseconds>\d+)us)?)$'
+ )
+ parts = regex.match(time_str)
+
+ if not parts:
+ raise ValueError('Invalid time duration - %s' % time_str)
+
+ parts = parts.groupdict()
+ time_params = {}
+ for (name, value) in parts.items():
+ if value:
+ time_params[name] = int(value)
+
+ delta = timedelta(**time_params)
+ time_in_nanoseconds = (
+ delta.microseconds + (delta.seconds + delta.days * 24 * 3600) * 10 ** 6
+ ) * 10 ** 3
+
+ return time_in_nanoseconds
+
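+# Illustrative conversions (not part of the original module):
+#
+#     convert_duration_to_nanosecond('1h30m')  ->  5400000000000
+#     convert_duration_to_nanosecond('500ms')  ->  500000000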
+
+def parse_healthcheck(healthcheck):
+ """
+ Return dictionary of healthcheck parameters and boolean if
+ healthcheck defined in image was requested to be disabled.
+ """
+ if (not healthcheck) or (not healthcheck.get('test')):
+ return None, None
+
+ result = dict()
+
+ # All supported healthcheck parameters
+ options = dict(
+ test='test',
+ interval='interval',
+ timeout='timeout',
+ start_period='start_period',
+ retries='retries'
+ )
+
+ duration_options = ['interval', 'timeout', 'start_period']
+
+ for (key, value) in options.items():
+ if value in healthcheck:
+ if healthcheck.get(value) is None:
+ # due to recursive argument_spec, all keys are always present
+ # (but have default value None if not specified)
+ continue
+ if value in duration_options:
+ time = convert_duration_to_nanosecond(healthcheck.get(value))
+ if time:
+ result[key] = time
+ elif healthcheck.get(value):
+ result[key] = healthcheck.get(value)
+ if key == 'test':
+ if isinstance(result[key], (tuple, list)):
+ result[key] = [str(e) for e in result[key]]
+ else:
+ result[key] = ['CMD-SHELL', str(result[key])]
+ elif key == 'retries':
+ try:
+ result[key] = int(result[key])
+ except ValueError:
+ raise ValueError(
+ 'Cannot parse number of retries for healthcheck. '
+ 'Expected an integer, got "{0}".'.format(result[key])
+ )
+
+ if result['test'] == ['NONE']:
+ # If the user explicitly disables the healthcheck, return None
+ # as the healthcheck object, and set disable_healthcheck to True
+ return None, True
+
+ return result, False
+
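+# Illustrative behaviour (not part of the original module):
+#
+#     parse_healthcheck({'test': 'curl -f http://localhost/', 'interval': '30s', 'retries': 3})
+#       ->  ({'test': ['CMD-SHELL', 'curl -f http://localhost/'],
+#             'interval': 30000000000, 'retries': 3}, False)
+#     parse_healthcheck({'test': ['NONE']})  ->  (None, True)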
+
+def omit_none_from_dict(d):
+ """
+ Return a copy of the dictionary with all keys with value None omitted.
+ """
+ return dict((k, v) for (k, v) in d.items() if v is not None)
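+
+
+# Illustrative behaviour (not part of the original module):
+#
+#     omit_none_from_dict({'a': 1, 'b': None})  ->  {'a': 1}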
diff --git a/test/support/integration/plugins/module_utils/docker/swarm.py b/test/support/integration/plugins/module_utils/docker/swarm.py
new file mode 100644
index 00000000..55d94db0
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/docker/swarm.py
@@ -0,0 +1,280 @@
+# (c) 2019 Piotr Wojciechowski (@wojciechowskipiotr) <piotr@it-playground.pl>
+# (c) Thierry Bouvet (@tbouvet)
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+
+import json
+from time import sleep
+
+try:
+ from docker.errors import APIError, NotFound
+except ImportError:
+ # missing Docker SDK for Python handled in ansible.module_utils.docker.common
+ pass
+
+from ansible.module_utils._text import to_native
+from ansible.module_utils.docker.common import (
+ AnsibleDockerClient,
+ LooseVersion,
+)
+
+
+class AnsibleDockerSwarmClient(AnsibleDockerClient):
+
+ def __init__(self, **kwargs):
+ super(AnsibleDockerSwarmClient, self).__init__(**kwargs)
+
+ def get_swarm_node_id(self):
+ """
+        Get the 'NodeID' of the Swarm node, or 'None' if the host is not in a Swarm. It returns the NodeID
+        of the Docker host the module is executed on
+ :return:
+ NodeID of host or 'None' if not part of Swarm
+ """
+
+ try:
+ info = self.info()
+ except APIError as exc:
+ self.fail("Failed to get node information for %s" % to_native(exc))
+
+ if info:
+ json_str = json.dumps(info, ensure_ascii=False)
+ swarm_info = json.loads(json_str)
+ if swarm_info['Swarm']['NodeID']:
+ return swarm_info['Swarm']['NodeID']
+ return None
+
+ def check_if_swarm_node(self, node_id=None):
+ """
+        Checks if the host is part of a Docker Swarm. If 'node_id' is not provided, it reads the Docker host
+        system information and looks for a specific key in the output. If 'node_id' is provided, it tries to
+        read the node information, assuming it is run on a Swarm manager. The get_node_inspect() method handles
+        the exception if it is not executed on a Swarm manager
+
+ :param node_id: Node identifier
+ :return:
+ bool: True if node is part of Swarm, False otherwise
+ """
+
+ if node_id is None:
+ try:
+ info = self.info()
+ except APIError:
+ self.fail("Failed to get host information.")
+
+ if info:
+ json_str = json.dumps(info, ensure_ascii=False)
+ swarm_info = json.loads(json_str)
+ if swarm_info['Swarm']['NodeID']:
+ return True
+ if swarm_info['Swarm']['LocalNodeState'] in ('active', 'pending', 'locked'):
+ return True
+ return False
+ else:
+ try:
+ node_info = self.get_node_inspect(node_id=node_id)
+ except APIError:
+ return
+
+ if node_info['ID'] is not None:
+ return True
+ return False
+
+ def check_if_swarm_manager(self):
+ """
+        Checks if the node role is set as Manager in Swarm. The node is the docker host on which the module
+        action is performed. The inspect_swarm() call will fail if the node is not a manager
+
+ :return: True if node is Swarm Manager, False otherwise
+ """
+
+ try:
+ self.inspect_swarm()
+ return True
+ except APIError:
+ return False
+
+ def fail_task_if_not_swarm_manager(self):
+ """
+        If the host is not a swarm manager, the Ansible task on this host should end in the 'failed' state
+ """
+ if not self.check_if_swarm_manager():
+ self.fail("Error running docker swarm module: must run on swarm manager node")
+
+ def check_if_swarm_worker(self):
+ """
+        Checks if the node role is set as Worker in Swarm. The node is the docker host on which the module
+        action is performed. Via check_if_swarm_node(), it will fail if run on a host that is not part of a Swarm
+
+ :return: True if node is Swarm Worker, False otherwise
+ """
+
+ if self.check_if_swarm_node() and not self.check_if_swarm_manager():
+ return True
+ return False
+
+ def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1):
+ """
+        Checks if the node status on the Swarm manager is 'down'. If node_id is provided, it queries the manager
+        about the node specified in the parameter; otherwise it queries the manager itself. If run on a Swarm
+        worker node or a host that is not part of a Swarm, it will fail the playbook
+
+        :param repeat_check: number of check attempts with a 5-second delay between them; by default, check only once
+        :param node_id: node ID or name; if None, the method will try to get the node_id of the host the module runs on
+ :return:
+ True if node is part of swarm but its state is down, False otherwise
+ """
+
+ if repeat_check < 1:
+ repeat_check = 1
+
+ if node_id is None:
+ node_id = self.get_swarm_node_id()
+
+ for retry in range(0, repeat_check):
+ if retry > 0:
+ sleep(5)
+ node_info = self.get_node_inspect(node_id=node_id)
+ if node_info['Status']['State'] == 'down':
+ return True
+ return False
+
+ def get_node_inspect(self, node_id=None, skip_missing=False):
+ """
+        Returns Swarm node info for a single node, as in the 'docker node inspect' command
+
+        :param skip_missing: if True, the function will return None instead of failing the task
+        :param node_id: node ID or name; if None, the method will try to get the node_id of the host the module runs on
+ :return:
+ Single node information structure
+ """
+
+ if node_id is None:
+ node_id = self.get_swarm_node_id()
+
+ if node_id is None:
+ self.fail("Failed to get node information.")
+
+ try:
+ node_info = self.inspect_node(node_id=node_id)
+ except APIError as exc:
+ if exc.status_code == 503:
+ self.fail("Cannot inspect node: To inspect node execute module on Swarm Manager")
+ if exc.status_code == 404:
+ if skip_missing:
+ return None
+ self.fail("Error while reading from Swarm manager: %s" % to_native(exc))
+ except Exception as exc:
+ self.fail("Error inspecting swarm node: %s" % exc)
+
+ json_str = json.dumps(node_info, ensure_ascii=False)
+ node_info = json.loads(json_str)
+
+ if 'ManagerStatus' in node_info:
+ if node_info['ManagerStatus'].get('Leader'):
+                # This is a workaround for a Docker bug where in some cases the Leader IP is 0.0.0.0
+ # Check moby/moby#35437 for details
+ count_colons = node_info['ManagerStatus']['Addr'].count(":")
+ if count_colons == 1:
+ swarm_leader_ip = node_info['ManagerStatus']['Addr'].split(":", 1)[0] or node_info['Status']['Addr']
+ else:
+ swarm_leader_ip = node_info['Status']['Addr']
+ node_info['Status']['Addr'] = swarm_leader_ip
+ return node_info
+
+ def get_all_nodes_inspect(self):
+ """
+        Returns Swarm node info for all registered nodes, as in the 'docker node inspect' command
+
+ :return:
+ Structure with information about all nodes
+ """
+ try:
+ node_info = self.nodes()
+ except APIError as exc:
+ if exc.status_code == 503:
+ self.fail("Cannot inspect node: To inspect node execute module on Swarm Manager")
+ self.fail("Error while reading from Swarm manager: %s" % to_native(exc))
+ except Exception as exc:
+ self.fail("Error inspecting swarm node: %s" % exc)
+
+ json_str = json.dumps(node_info, ensure_ascii=False)
+ node_info = json.loads(json_str)
+ return node_info
+
+ def get_all_nodes_list(self, output='short'):
+ """
+ Returns list of nodes registered in Swarm
+
+ :param output: Defines format of returned data
+ :return:
+            If 'output' is 'short', the return value is a list of the hostnames of the nodes registered
+            in the Swarm; if 'output' is 'long', it is a list of dicts containing the attributes as in the
+            output of the command 'docker node ls'
+ """
+ nodes_list = []
+
+ nodes_inspect = self.get_all_nodes_inspect()
+ if nodes_inspect is None:
+ return None
+
+ if output == 'short':
+ for node in nodes_inspect:
+ nodes_list.append(node['Description']['Hostname'])
+ elif output == 'long':
+ for node in nodes_inspect:
+ node_property = {}
+
+ node_property.update({'ID': node['ID']})
+ node_property.update({'Hostname': node['Description']['Hostname']})
+ node_property.update({'Status': node['Status']['State']})
+ node_property.update({'Availability': node['Spec']['Availability']})
+ if 'ManagerStatus' in node:
+ if node['ManagerStatus']['Leader'] is True:
+ node_property.update({'Leader': True})
+ node_property.update({'ManagerStatus': node['ManagerStatus']['Reachability']})
+ node_property.update({'EngineVersion': node['Description']['Engine']['EngineVersion']})
+
+ nodes_list.append(node_property)
+ else:
+ return None
+
+ return nodes_list
+
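+    # Illustrative 'long' output entry (not part of the original module; values are hypothetical):
+    #
+    #     {'ID': 'abc123', 'Hostname': 'node-1', 'Status': 'ready',
+    #      'Availability': 'active', 'EngineVersion': '19.03.5'}
+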
+ def get_node_name_by_id(self, nodeid):
+ return self.get_node_inspect(nodeid)['Description']['Hostname']
+
+ def get_unlock_key(self):
+ if self.docker_py_version < LooseVersion('2.7.0'):
+ return None
+ return super(AnsibleDockerSwarmClient, self).get_unlock_key()
+
+ def get_service_inspect(self, service_id, skip_missing=False):
+ """
+        Returns Swarm service info for a single service, as in the 'docker service inspect' command
+
+ :param service_id: service ID or name
+        :param skip_missing: if True, the function will return None instead of failing the task
+ :return:
+ Single service information structure
+ """
+ try:
+ service_info = self.inspect_service(service_id)
+ except NotFound as exc:
+ if skip_missing is False:
+ self.fail("Error while reading from Swarm manager: %s" % to_native(exc))
+ else:
+ return None
+ except APIError as exc:
+ if exc.status_code == 503:
+ self.fail("Cannot inspect service: To inspect service execute module on Swarm Manager")
+ self.fail("Error inspecting swarm service: %s" % exc)
+ except Exception as exc:
+ self.fail("Error inspecting swarm service: %s" % exc)
+
+ json_str = json.dumps(service_info, ensure_ascii=False)
+ service_info = json.loads(json_str)
+ return service_info
diff --git a/test/support/integration/plugins/module_utils/ec2.py b/test/support/integration/plugins/module_utils/ec2.py
new file mode 100644
index 00000000..0d28108d
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/ec2.py
@@ -0,0 +1,758 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+import re
+import sys
+import traceback
+
+from ansible.module_utils.ansible_release import __version__
+from ansible.module_utils.basic import missing_required_lib, env_fallback
+from ansible.module_utils._text import to_native, to_text
+from ansible.module_utils.cloud import CloudRetry
+from ansible.module_utils.six import string_types, binary_type, text_type
+from ansible.module_utils.common.dict_transformations import (
+ camel_dict_to_snake_dict, snake_dict_to_camel_dict,
+ _camel_to_snake, _snake_to_camel,
+)
+
+BOTO_IMP_ERR = None
+try:
+ import boto
+ import boto.ec2 # boto does weird import stuff
+ HAS_BOTO = True
+except ImportError:
+ BOTO_IMP_ERR = traceback.format_exc()
+ HAS_BOTO = False
+
+BOTO3_IMP_ERR = None
+try:
+ import boto3
+ import botocore
+ HAS_BOTO3 = True
+except Exception:
+ BOTO3_IMP_ERR = traceback.format_exc()
+ HAS_BOTO3 = False
+
+try:
+    # Although this exists to allow Python 3 to use the custom comparison as a key, Python 2.7 also
+    # uses it (and it works as expected). Python 2.6 will trigger the ImportError.
+ from functools import cmp_to_key
+ PY3_COMPARISON = True
+except ImportError:
+ PY3_COMPARISON = False
+
+
+class AnsibleAWSError(Exception):
+ pass
+
+
+def _botocore_exception_maybe():
+ """
+ Allow for boto3 not being installed when using these utils by wrapping
+ botocore.exceptions instead of assigning from it directly.
+ """
+ if HAS_BOTO3:
+ return botocore.exceptions.ClientError
+ return type(None)
+
+
+class AWSRetry(CloudRetry):
+ base_class = _botocore_exception_maybe()
+
+ @staticmethod
+ def status_code_from_exception(error):
+ return error.response['Error']['Code']
+
+ @staticmethod
+ def found(response_code, catch_extra_error_codes=None):
+ # This list of failures is based on this API Reference
+ # http://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html
+ #
+ # TooManyRequestsException comes from inside botocore when it
+        # does retries; unfortunately, it does not retry long
+ # enough to allow some services such as API Gateway to
+ # complete configuration. At the moment of writing there is a
+ # botocore/boto3 bug open to fix this.
+ #
+ # https://github.com/boto/boto3/issues/876 (and linked PRs etc)
+ retry_on = [
+ 'RequestLimitExceeded', 'Unavailable', 'ServiceUnavailable',
+ 'InternalFailure', 'InternalError', 'TooManyRequestsException',
+ 'Throttling'
+ ]
+ if catch_extra_error_codes:
+ retry_on.extend(catch_extra_error_codes)
+
+ return response_code in retry_on
+
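+# Illustrative usage (not part of the original module), assuming the CloudRetry
+# base class provides the usual jittered_backoff() decorator factory:
+#
+#     @AWSRetry.jittered_backoff(retries=5)
+#     def describe_vpcs(client):
+#         return client.describe_vpcs()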
+
+def boto3_conn(module, conn_type=None, resource=None, region=None, endpoint=None, **params):
+ try:
+ return _boto3_conn(conn_type=conn_type, resource=resource, region=region, endpoint=endpoint, **params)
+ except ValueError as e:
+ module.fail_json(msg="Couldn't connect to AWS: %s" % to_native(e))
+ except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError,
+ botocore.exceptions.NoCredentialsError, botocore.exceptions.ConfigParseError) as e:
+ module.fail_json(msg=to_native(e))
+ except botocore.exceptions.NoRegionError as e:
+ module.fail_json(msg="The %s module requires a region and none was found in configuration, "
+ "environment variables or module parameters" % module._name)
+
+
+def _boto3_conn(conn_type=None, resource=None, region=None, endpoint=None, **params):
+ profile = params.pop('profile_name', None)
+
+ if conn_type not in ['both', 'resource', 'client']:
+ raise ValueError('There is an issue in the calling code. You '
+ 'must specify either both, resource, or client to '
+ 'the conn_type parameter in the boto3_conn function '
+ 'call')
+
+ config = botocore.config.Config(
+ user_agent_extra='Ansible/{0}'.format(__version__),
+ )
+
+ if params.get('config') is not None:
+ config = config.merge(params.pop('config'))
+ if params.get('aws_config') is not None:
+ config = config.merge(params.pop('aws_config'))
+
+ session = boto3.session.Session(
+ profile_name=profile,
+ )
+
+ if conn_type == 'resource':
+ return session.resource(resource, config=config, region_name=region, endpoint_url=endpoint, **params)
+ elif conn_type == 'client':
+ return session.client(resource, config=config, region_name=region, endpoint_url=endpoint, **params)
+ else:
+ client = session.client(resource, region_name=region, endpoint_url=endpoint, **params)
+ resource = session.resource(resource, region_name=region, endpoint_url=endpoint, **params)
+ return client, resource
+
+
+boto3_inventory_conn = _boto3_conn
+
+
+def boto_exception(err):
+ """
+ Extracts the error message from a boto exception.
+
+ :param err: Exception from boto
+ :return: Error message
+ """
+ if hasattr(err, 'error_message'):
+ error = err.error_message
+ elif hasattr(err, 'message'):
+ error = str(err.message) + ' ' + str(err) + ' - ' + str(type(err))
+ else:
+ error = '%s: %s' % (Exception, err)
+
+ return error
+
+
+def aws_common_argument_spec():
+ return dict(
+ debug_botocore_endpoint_logs=dict(fallback=(env_fallback, ['ANSIBLE_DEBUG_BOTOCORE_LOGS']), default=False, type='bool'),
+ ec2_url=dict(),
+ aws_secret_key=dict(aliases=['ec2_secret_key', 'secret_key'], no_log=True),
+ aws_access_key=dict(aliases=['ec2_access_key', 'access_key']),
+ validate_certs=dict(default=True, type='bool'),
+ security_token=dict(aliases=['access_token'], no_log=True),
+ profile=dict(),
+ aws_config=dict(type='dict'),
+ )
+
+
+def ec2_argument_spec():
+ spec = aws_common_argument_spec()
+ spec.update(
+ dict(
+ region=dict(aliases=['aws_region', 'ec2_region']),
+ )
+ )
+ return spec
+
+
+def get_aws_region(module, boto3=False):
+ region = module.params.get('region')
+
+ if region:
+ return region
+
+ if 'AWS_REGION' in os.environ:
+ return os.environ['AWS_REGION']
+ if 'AWS_DEFAULT_REGION' in os.environ:
+ return os.environ['AWS_DEFAULT_REGION']
+ if 'EC2_REGION' in os.environ:
+ return os.environ['EC2_REGION']
+
+ if not boto3:
+ if not HAS_BOTO:
+ module.fail_json(msg=missing_required_lib('boto'), exception=BOTO_IMP_ERR)
+ # boto.config.get returns None if config not found
+ region = boto.config.get('Boto', 'aws_region')
+ if region:
+ return region
+ return boto.config.get('Boto', 'ec2_region')
+
+ if not HAS_BOTO3:
+ module.fail_json(msg=missing_required_lib('boto3'), exception=BOTO3_IMP_ERR)
+
+    # Here we don't need to make an additional call; it will default to 'us-east-1' if the below evaluates to None.
+ try:
+ profile_name = module.params.get('profile')
+ return botocore.session.Session(profile=profile_name).get_config_variable('region')
+ except botocore.exceptions.ProfileNotFound as e:
+ return None
+
+
+def get_aws_connection_info(module, boto3=False):
+
+ # Check module args for credentials, then check environment vars
+ # access_key
+
+ ec2_url = module.params.get('ec2_url')
+ access_key = module.params.get('aws_access_key')
+ secret_key = module.params.get('aws_secret_key')
+ security_token = module.params.get('security_token')
+ region = get_aws_region(module, boto3)
+ profile_name = module.params.get('profile')
+ validate_certs = module.params.get('validate_certs')
+ config = module.params.get('aws_config')
+
+ if not ec2_url:
+ if 'AWS_URL' in os.environ:
+ ec2_url = os.environ['AWS_URL']
+ elif 'EC2_URL' in os.environ:
+ ec2_url = os.environ['EC2_URL']
+
+ if not access_key:
+ if os.environ.get('AWS_ACCESS_KEY_ID'):
+ access_key = os.environ['AWS_ACCESS_KEY_ID']
+ elif os.environ.get('AWS_ACCESS_KEY'):
+ access_key = os.environ['AWS_ACCESS_KEY']
+ elif os.environ.get('EC2_ACCESS_KEY'):
+ access_key = os.environ['EC2_ACCESS_KEY']
+ elif HAS_BOTO and boto.config.get('Credentials', 'aws_access_key_id'):
+ access_key = boto.config.get('Credentials', 'aws_access_key_id')
+ elif HAS_BOTO and boto.config.get('default', 'aws_access_key_id'):
+ access_key = boto.config.get('default', 'aws_access_key_id')
+ else:
+ # in case access_key came in as empty string
+ access_key = None
+
+ if not secret_key:
+ if os.environ.get('AWS_SECRET_ACCESS_KEY'):
+ secret_key = os.environ['AWS_SECRET_ACCESS_KEY']
+ elif os.environ.get('AWS_SECRET_KEY'):
+ secret_key = os.environ['AWS_SECRET_KEY']
+ elif os.environ.get('EC2_SECRET_KEY'):
+ secret_key = os.environ['EC2_SECRET_KEY']
+ elif HAS_BOTO and boto.config.get('Credentials', 'aws_secret_access_key'):
+ secret_key = boto.config.get('Credentials', 'aws_secret_access_key')
+ elif HAS_BOTO and boto.config.get('default', 'aws_secret_access_key'):
+ secret_key = boto.config.get('default', 'aws_secret_access_key')
+ else:
+ # in case secret_key came in as empty string
+ secret_key = None
+
+ if not security_token:
+ if os.environ.get('AWS_SECURITY_TOKEN'):
+ security_token = os.environ['AWS_SECURITY_TOKEN']
+ elif os.environ.get('AWS_SESSION_TOKEN'):
+ security_token = os.environ['AWS_SESSION_TOKEN']
+ elif os.environ.get('EC2_SECURITY_TOKEN'):
+ security_token = os.environ['EC2_SECURITY_TOKEN']
+ elif HAS_BOTO and boto.config.get('Credentials', 'aws_security_token'):
+ security_token = boto.config.get('Credentials', 'aws_security_token')
+ elif HAS_BOTO and boto.config.get('default', 'aws_security_token'):
+ security_token = boto.config.get('default', 'aws_security_token')
+ else:
+ # in case security_token came in as empty string
+ security_token = None
+
+ if HAS_BOTO3 and boto3:
+ boto_params = dict(aws_access_key_id=access_key,
+ aws_secret_access_key=secret_key,
+ aws_session_token=security_token)
+ boto_params['verify'] = validate_certs
+
+ if profile_name:
+ boto_params = dict(aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None)
+ boto_params['profile_name'] = profile_name
+
+ else:
+ boto_params = dict(aws_access_key_id=access_key,
+ aws_secret_access_key=secret_key,
+ security_token=security_token)
+
+ # only set profile_name if passed as an argument
+ if profile_name:
+ boto_params['profile_name'] = profile_name
+
+ boto_params['validate_certs'] = validate_certs
+
+ if config is not None:
+ if HAS_BOTO3 and boto3:
+ boto_params['aws_config'] = botocore.config.Config(**config)
+ elif HAS_BOTO and not boto3:
+ if 'user_agent' in config:
+ sys.modules["boto.connection"].UserAgent = config['user_agent']
+
+ for param, value in boto_params.items():
+ if isinstance(value, binary_type):
+ boto_params[param] = text_type(value, 'utf-8', 'strict')
+
+ return region, ec2_url, boto_params
+
+
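+# Illustrative usage sketch for readers of this vendored copy (comments only,
+# nothing executed at import). 'module' is a hypothetical AnsibleModule whose
+# argument spec includes ec2_argument_spec():
+#
+#   region, ec2_url, boto_params = get_aws_connection_info(module, boto3=True)
+#   client = boto3_conn(module, conn_type='client', resource='ec2',
+#                       region=region, endpoint=ec2_url, **boto_params)
+
+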
+def get_ec2_creds(module):
+ ''' for compatibility mode with old modules that don't/can't yet
+ use ec2_connect method '''
+ region, ec2_url, boto_params = get_aws_connection_info(module)
+ return ec2_url, boto_params['aws_access_key_id'], boto_params['aws_secret_access_key'], region
+
+
+def boto_fix_security_token_in_profile(conn, profile_name):
+ ''' monkey patch for boto issue boto/boto#2100 '''
+ profile = 'profile ' + profile_name
+ if boto.config.has_option(profile, 'aws_security_token'):
+ conn.provider.set_security_token(boto.config.get(profile, 'aws_security_token'))
+ return conn
+
+
+def connect_to_aws(aws_module, region, **params):
+ try:
+ conn = aws_module.connect_to_region(region, **params)
+ except boto.provider.ProfileNotFoundError:
+ raise AnsibleAWSError("Profile given for AWS was not found. Please fix and retry.")
+ if not conn:
+ if region not in [aws_module_region.name for aws_module_region in aws_module.regions()]:
+ raise AnsibleAWSError("Region %s does not seem to be available for aws module %s. If the region definitely exists, you may need to upgrade "
+ "boto or extend with endpoints_path" % (region, aws_module.__name__))
+ else:
+ raise AnsibleAWSError("Unknown problem connecting to region %s for aws module %s." % (region, aws_module.__name__))
+ if params.get('profile_name'):
+ conn = boto_fix_security_token_in_profile(conn, params['profile_name'])
+ return conn
+
+
+def ec2_connect(module):
+
+ """ Return an ec2 connection"""
+
+ region, ec2_url, boto_params = get_aws_connection_info(module)
+
+ # If we have a region specified, connect to its endpoint.
+ if region:
+ try:
+ ec2 = connect_to_aws(boto.ec2, region, **boto_params)
+ except (boto.exception.NoAuthHandlerFound, AnsibleAWSError, boto.provider.ProfileNotFoundError) as e:
+ module.fail_json(msg=str(e))
+ # Otherwise, no region so we fallback to the old connection method
+ elif ec2_url:
+ try:
+ ec2 = boto.connect_ec2_endpoint(ec2_url, **boto_params)
+ except (boto.exception.NoAuthHandlerFound, AnsibleAWSError, boto.provider.ProfileNotFoundError) as e:
+ module.fail_json(msg=str(e))
+ else:
+ module.fail_json(msg="Either region or ec2_url must be specified")
+
+ return ec2
+
+
+def ansible_dict_to_boto3_filter_list(filters_dict):
+
+ """ Convert an Ansible dict of filters to list of dicts that boto3 can use
+ Args:
+ filters_dict (dict): Dict of AWS filters.
+ Basic Usage:
+ >>> filters = {'some-aws-id': 'i-01234567'}
+ >>> ansible_dict_to_boto3_filter_list(filters)
+ [{'Name': 'some-aws-id', 'Values': ['i-01234567']}]
+ Returns:
+ List: List of AWS filters and their values
+ [
+ {
+ 'Name': 'some-aws-id',
+ 'Values': [
+ 'i-01234567',
+ ]
+ }
+ ]
+ """
+
+ filters_list = []
+ for k, v in filters_dict.items():
+ filter_dict = {'Name': k}
+ if isinstance(v, string_types):
+ filter_dict['Values'] = [v]
+ else:
+ filter_dict['Values'] = v
+
+ filters_list.append(filter_dict)
+
+ return filters_list
+
+
+def boto3_tag_list_to_ansible_dict(tags_list, tag_name_key_name=None, tag_value_key_name=None):
+
+ """ Convert a boto3 list of resource tags to a flat dict of key:value pairs
+ Args:
+ tags_list (list): List of dicts representing AWS tags.
+ tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key")
+ tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value")
+ Basic Usage:
+ >>> tags_list = [{'Key': 'MyTagKey', 'Value': 'MyTagValue'}]
+ >>> boto3_tag_list_to_ansible_dict(tags_list)
+ {'MyTagKey': 'MyTagValue'}
+ Returns:
+ Dict: Dict of key:value pairs representing AWS tags
+ {
+ 'MyTagKey': 'MyTagValue',
+ }
+ """
+
+ if tag_name_key_name and tag_value_key_name:
+ tag_candidates = {tag_name_key_name: tag_value_key_name}
+ else:
+ tag_candidates = {'key': 'value', 'Key': 'Value'}
+
+ if not tags_list:
+ return {}
+ for k, v in tag_candidates.items():
+ if k in tags_list[0] and v in tags_list[0]:
+ return dict((tag[k], tag[v]) for tag in tags_list)
+ raise ValueError("Couldn't find tag key (candidates %s) in tag list %s" % (str(tag_candidates), str(tags_list)))
+
+
+def ansible_dict_to_boto3_tag_list(tags_dict, tag_name_key_name='Key', tag_value_key_name='Value'):
+
+ """ Convert a flat dict of key:value pairs representing AWS resource tags to a boto3 list of dicts
+ Args:
+ tags_dict (dict): Dict representing AWS resource tags.
+ tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key")
+ tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value")
+ Basic Usage:
+ >>> tags_dict = {'MyTagKey': 'MyTagValue'}
+ >>> ansible_dict_to_boto3_tag_list(tags_dict)
+ [{'Key': 'MyTagKey', 'Value': 'MyTagValue'}]
+ Returns:
+ List: List of dicts containing tag keys and values
+ [
+ {
+ 'Key': 'MyTagKey',
+ 'Value': 'MyTagValue'
+ }
+ ]
+ """
+
+ tags_list = []
+ for k, v in tags_dict.items():
+ tags_list.append({tag_name_key_name: k, tag_value_key_name: to_native(v)})
+
+ return tags_list
+
+
+def get_ec2_security_group_ids_from_names(sec_group_list, ec2_connection, vpc_id=None, boto3=True):
+
+ """ Return list of security group IDs from security group names. Note that security group names are not unique
+ across VPCs. If a name exists across multiple VPCs and no VPC ID is supplied, all matching IDs will be returned. This
+ will probably lead to a boto exception if you attempt to assign both IDs to a resource, so ensure you wrap the call in
+ a try block.
+ """
+
+ def get_sg_name(sg, boto3):
+
+ if boto3:
+ return sg['GroupName']
+ else:
+ return sg.name
+
+ def get_sg_id(sg, boto3):
+
+ if boto3:
+ return sg['GroupId']
+ else:
+ return sg.id
+
+ sec_group_id_list = []
+
+ if isinstance(sec_group_list, string_types):
+ sec_group_list = [sec_group_list]
+
+ # Get all security groups
+ if boto3:
+ if vpc_id:
+ filters = [
+ {
+ 'Name': 'vpc-id',
+ 'Values': [
+ vpc_id,
+ ]
+ }
+ ]
+ all_sec_groups = ec2_connection.describe_security_groups(Filters=filters)['SecurityGroups']
+ else:
+ all_sec_groups = ec2_connection.describe_security_groups()['SecurityGroups']
+ else:
+ if vpc_id:
+ filters = {'vpc-id': vpc_id}
+ all_sec_groups = ec2_connection.get_all_security_groups(filters=filters)
+ else:
+ all_sec_groups = ec2_connection.get_all_security_groups()
+
+ unmatched = set(sec_group_list).difference(str(get_sg_name(all_sg, boto3)) for all_sg in all_sec_groups)
+ sec_group_name_list = list(set(sec_group_list) - set(unmatched))
+
+ if len(unmatched) > 0:
+ # If we have unmatched names that look like an ID, assume they are
+ import re
+ sec_group_id_list = [sg for sg in unmatched if re.match('sg-[a-fA-F0-9]+$', sg)]
+ still_unmatched = [sg for sg in unmatched if not re.match('sg-[a-fA-F0-9]+$', sg)]
+ if len(still_unmatched) > 0:
+ raise ValueError("The following group names are not valid: %s" % ', '.join(still_unmatched))
+
+ sec_group_id_list += [str(get_sg_id(all_sg, boto3)) for all_sg in all_sec_groups if str(get_sg_name(all_sg, boto3)) in sec_group_name_list]
+
+ return sec_group_id_list
+
+
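+# Illustrative sketch (comments only): 'ec2_connection' is a hypothetical boto3
+# EC2 client and the group names/IDs below are placeholders. Names resolve to
+# IDs, and anything already shaped like 'sg-...' is passed through unchanged:
+#
+#   ids = get_ec2_security_group_ids_from_names(
+#       ['default', 'sg-0123456789abcdef0'], ec2_connection, vpc_id='vpc-1234567')
+
+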
+def _hashable_policy(policy, policy_list):
+ """
+ Takes a policy and returns a list, the contents of which are all hashable and sorted.
+ Example input policy:
+ {'Version': '2012-10-17',
+ 'Statement': [{'Action': 's3:PutObjectAcl',
+ 'Sid': 'AddCannedAcl2',
+ 'Resource': 'arn:aws:s3:::test_policy/*',
+ 'Effect': 'Allow',
+ 'Principal': {'AWS': ['arn:aws:iam::XXXXXXXXXXXX:user/username1', 'arn:aws:iam::XXXXXXXXXXXX:user/username2']}
+ }]}
+ Returned value:
+ [('Statement', ((('Action', (u's3:PutObjectAcl',)),
+ ('Effect', (u'Allow',)),
+ ('Principal', ('AWS', ((u'arn:aws:iam::XXXXXXXXXXXX:user/username1',), (u'arn:aws:iam::XXXXXXXXXXXX:user/username2',)))),
+ ('Resource', (u'arn:aws:s3:::test_policy/*',)), ('Sid', (u'AddCannedAcl2',)))),
+ ('Version', (u'2012-10-17',)))]
+
+ """
+ # Amazon will automatically convert bool and int to strings for us
+ if isinstance(policy, bool):
+ return tuple([str(policy).lower()])
+ elif isinstance(policy, int):
+ return tuple([str(policy)])
+
+ if isinstance(policy, list):
+ for each in policy:
+ tupleified = _hashable_policy(each, [])
+ if isinstance(tupleified, list):
+ tupleified = tuple(tupleified)
+ policy_list.append(tupleified)
+ elif isinstance(policy, string_types) or isinstance(policy, binary_type):
+ policy = to_text(policy)
+ # convert root account ARNs to just account IDs
+ if policy.startswith('arn:aws:iam::') and policy.endswith(':root'):
+ policy = policy.split(':')[4]
+ return [policy]
+ elif isinstance(policy, dict):
+ sorted_keys = list(policy.keys())
+ sorted_keys.sort()
+ for key in sorted_keys:
+ tupleified = _hashable_policy(policy[key], [])
+ if isinstance(tupleified, list):
+ tupleified = tuple(tupleified)
+ policy_list.append((key, tupleified))
+
+ # ensure we aren't returning deeply nested structures of length 1
+ if len(policy_list) == 1 and isinstance(policy_list[0], tuple):
+ policy_list = policy_list[0]
+ if isinstance(policy_list, list):
+ if PY3_COMPARISON:
+ policy_list.sort(key=cmp_to_key(py3cmp))
+ else:
+ policy_list.sort()
+ return policy_list
+
+
+def py3cmp(a, b):
+ """ Python 2 can sort lists of mixed types. Strings < tuples. Without this function this fails on Python 3."""
+ try:
+ if a > b:
+ return 1
+ elif a < b:
+ return -1
+ else:
+ return 0
+ except TypeError as e:
+ # check to see if they're tuple-string
+ # always say strings are less than tuples (to maintain compatibility with python2)
+ str_ind = to_text(e).find('str')
+ tup_ind = to_text(e).find('tuple')
+ if -1 not in (str_ind, tup_ind):
+ if str_ind < tup_ind:
+ return -1
+ elif tup_ind < str_ind:
+ return 1
+ raise
+
+
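+# Illustrative sketch (comments only), using functools.cmp_to_key as
+# _hashable_policy does above. Strings sort before tuples, mirroring Python 2:
+#
+#   >>> sorted(['b', ('a',), 'a'], key=cmp_to_key(py3cmp))
+#   ['a', 'b', ('a',)]
+
+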
+def compare_policies(current_policy, new_policy):
+ """ Compares the existing policy and the updated policy
+ Returns True if there is a difference between policies.
+ """
+ return set(_hashable_policy(new_policy, [])) != set(_hashable_policy(current_policy, []))
+
+
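+# Illustrative sketch (comments only): only differing policy content registers
+# as a change, regardless of key or statement ordering:
+#
+#   >>> compare_policies({'Version': '2012-10-17'}, {'Version': '2012-10-17'})
+#   False
+#   >>> compare_policies({'Version': '2012-10-17'}, {'Version': '2008-10-17'})
+#   True
+
+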
+def sort_json_policy_dict(policy_dict):
+
+ """ Sort any lists in an IAM JSON policy so that comparison of two policies with identical values but
+ different orders will compare as equal
+ Args:
+ policy_dict (dict): Dict representing IAM JSON policy.
+ Basic Usage:
+ >>> my_iam_policy = {'Principle': {'AWS': ["31", "7", "14", "101"]}}
+ >>> sort_json_policy_dict(my_iam_policy)
+ Returns:
+ Dict: Will return a copy of the policy as a Dict but any List will be sorted
+ {
+ 'Principle': {
+ 'AWS': [ '101', '14', '31', '7' ]
+ }
+ }
+ """
+
+ def value_is_list(my_list):
+
+ checked_list = []
+ for item in my_list:
+ if isinstance(item, dict):
+ checked_list.append(sort_json_policy_dict(item))
+ elif isinstance(item, list):
+ checked_list.append(value_is_list(item))
+ else:
+ checked_list.append(item)
+
+ # Sort list. If it's a list of dictionaries, sort by tuple of key-value
+ # pairs, since Python 3 doesn't allow comparisons such as `<` between dictionaries.
+ checked_list.sort(key=lambda x: sorted(x.items()) if isinstance(x, dict) else x)
+ return checked_list
+
+ ordered_policy_dict = {}
+ for key, value in policy_dict.items():
+ if isinstance(value, dict):
+ ordered_policy_dict[key] = sort_json_policy_dict(value)
+ elif isinstance(value, list):
+ ordered_policy_dict[key] = value_is_list(value)
+ else:
+ ordered_policy_dict[key] = value
+
+ return ordered_policy_dict
+
+
+def map_complex_type(complex_type, type_map):
+ """
+ Allows elements within a dictionary to be cast to a specific type
+ Example of usage:
+
+ DEPLOYMENT_CONFIGURATION_TYPE_MAP = {
+ 'maximum_percent': 'int',
+ 'minimum_healthy_percent': 'int'
+ }
+
+ deployment_configuration = map_complex_type(module.params['deployment_configuration'],
+ DEPLOYMENT_CONFIGURATION_TYPE_MAP)
+
+ This ensures all values for the mapped keys within the root element are cast to valid integers
+ """
+
+ if complex_type is None:
+ return
+ new_type = type(complex_type)()
+ if isinstance(complex_type, dict):
+ for key in complex_type:
+ if key in type_map:
+ if isinstance(type_map[key], list):
+ new_type[key] = map_complex_type(
+ complex_type[key],
+ type_map[key][0])
+ else:
+ new_type[key] = map_complex_type(
+ complex_type[key],
+ type_map[key])
+ else:
+ return complex_type
+ elif isinstance(complex_type, list):
+ for i in range(len(complex_type)):
+ new_type.append(map_complex_type(
+ complex_type[i],
+ type_map))
+ elif type_map:
+ return globals()['__builtins__'][type_map](complex_type)
+ return new_type
+
+
+def compare_aws_tags(current_tags_dict, new_tags_dict, purge_tags=True):
+ """
+ Compare two dicts of AWS tags. Dicts are expected to have been created using the 'boto3_tag_list_to_ansible_dict' helper function.
+ Two dicts are returned - the first is tags to be set, the second is any tags to remove. Since the AWS APIs differ,
+ these may not be usable out of the box.
+
+ :param current_tags_dict:
+ :param new_tags_dict:
+ :param purge_tags:
+ :return: tag_key_value_pairs_to_set: a dict of key value pairs that need to be set in AWS. If all tags are identical this dict will be empty
+ :return: tag_keys_to_unset: a list of key names (type str) that need to be unset in AWS. If no tags need to be unset this list will be empty
+ """
+
+ tag_key_value_pairs_to_set = {}
+ tag_keys_to_unset = []
+
+ for key in current_tags_dict.keys():
+ if key not in new_tags_dict and purge_tags:
+ tag_keys_to_unset.append(key)
+
+ for key in set(new_tags_dict.keys()) - set(tag_keys_to_unset):
+ if to_text(new_tags_dict[key]) != current_tags_dict.get(key):
+ tag_key_value_pairs_to_set[key] = new_tags_dict[key]
+
+ return tag_key_value_pairs_to_set, tag_keys_to_unset
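+
+
+# Illustrative sketch (comments only) of the two return values; dict key
+# ordering may vary:
+#
+#   >>> current = {'env': 'prod', 'owner': 'alice'}
+#   >>> new = {'env': 'staging', 'team': 'infra'}
+#   >>> compare_aws_tags(current, new, purge_tags=True)
+#   ({'env': 'staging', 'team': 'infra'}, ['owner'])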
diff --git a/test/support/integration/plugins/module_utils/ecs/__init__.py b/test/support/integration/plugins/module_utils/ecs/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/ecs/__init__.py
diff --git a/test/support/integration/plugins/module_utils/ecs/api.py b/test/support/integration/plugins/module_utils/ecs/api.py
new file mode 100644
index 00000000..d89b0333
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/ecs/api.py
@@ -0,0 +1,364 @@
+# -*- coding: utf-8 -*-
+
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is licensed under the
+# Modified BSD License. Modules you write using this snippet, which is embedded
+# dynamically by Ansible, still belong to the author of the module, and may assign
+# their own license to the complete work.
+#
+# Copyright (c), Entrust Datacard Corporation, 2019
+# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
+
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import json
+import os
+import re
+import time
+import traceback
+
+from ansible.module_utils._text import to_text, to_native
+from ansible.module_utils.basic import missing_required_lib
+from ansible.module_utils.six.moves.urllib.parse import urlencode
+from ansible.module_utils.six.moves.urllib.error import HTTPError
+from ansible.module_utils.urls import Request
+
+YAML_IMP_ERR = None
+try:
+ import yaml
+except ImportError:
+ YAML_FOUND = False
+ YAML_IMP_ERR = traceback.format_exc()
+else:
+ YAML_FOUND = True
+
+valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$")
+
+
+def ecs_client_argument_spec():
+ return dict(
+ entrust_api_user=dict(type='str', required=True),
+ entrust_api_key=dict(type='str', required=True, no_log=True),
+ entrust_api_client_cert_path=dict(type='path', required=True),
+ entrust_api_client_cert_key_path=dict(type='path', required=True, no_log=True),
+ entrust_api_specification_path=dict(type='path', default='https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml'),
+ )
+
+
+class SessionConfigurationException(Exception):
+ """ Raised if we cannot configure a session with the API """
+
+ pass
+
+
+class RestOperationException(Exception):
+ """ Encapsulate a REST API error """
+
+ def __init__(self, error):
+ self.status = to_native(error.get("status", None))
+ self.errors = [to_native(err.get("message")) for err in error.get("errors", {})]
+ self.message = to_native(" ".join(self.errors))
+
+
+def generate_docstring(operation_spec):
+ """Generate a docstring for an operation defined in operation_spec (swagger)"""
+ # Description of the operation
+ docs = operation_spec.get("description", "No Description")
+ docs += "\n\n"
+
+ # Parameters of the operation
+ parameters = operation_spec.get("parameters", [])
+ if len(parameters) != 0:
+ docs += "\tArguments:\n\n"
+ for parameter in parameters:
+ docs += "{0} ({1}:{2}): {3}\n".format(
+ parameter.get("name"),
+ parameter.get("type", "No Type"),
+ "Required" if parameter.get("required", False) else "Not Required",
+ parameter.get("description"),
+ )
+
+ return docs
+
+
+def bind(instance, method, operation_spec):
+ def binding_scope_fn(*args, **kwargs):
+ return method(instance, *args, **kwargs)
+
+ # Make sure we don't confuse users; add the proper name and documentation to the function.
+ # Users can use !help(<function>) to get help on the function from interactive python or pdb
+ operation_name = operation_spec.get("operationId").split("Using")[0]
+ binding_scope_fn.__name__ = str(operation_name)
+ binding_scope_fn.__doc__ = generate_docstring(operation_spec)
+
+ return binding_scope_fn
+
+
+class RestOperation(object):
+ def __init__(self, session, uri, method, parameters=None):
+ self.session = session
+ self.method = method
+ if parameters is None:
+ self.parameters = {}
+ else:
+ self.parameters = parameters
+ self.url = "{scheme}://{host}{base_path}{uri}".format(scheme="https", host=session._spec.get("host"), base_path=session._spec.get("basePath"), uri=uri)
+
+ def restmethod(self, *args, **kwargs):
+ """Do the hard work of making the request here"""
+
+ # gather named path parameters and do substitution on the URL
+ if self.parameters:
+ path_parameters = {}
+ body_parameters = {}
+ query_parameters = {}
+ for x in self.parameters:
+ expected_location = x.get("in")
+ key_name = x.get("name", None)
+ key_value = kwargs.get(key_name, None)
+ if expected_location == "path" and key_name and key_value:
+ path_parameters.update({key_name: key_value})
+ elif expected_location == "body" and key_name and key_value:
+ body_parameters.update({key_name: key_value})
+ elif expected_location == "query" and key_name and key_value:
+ query_parameters.update({key_name: key_value})
+
+ if len(body_parameters.keys()) >= 1:
+ body_parameters = body_parameters.get(list(body_parameters.keys())[0])
+ else:
+ body_parameters = None
+ else:
+ path_parameters = {}
+ query_parameters = {}
+ body_parameters = None
+
+ # This will fail with a KeyError if required path parameters have not been set
+ url = self.url.format(**path_parameters)
+ if query_parameters:
+ # append the query parameters to the URL
+ url = url + "?" + urlencode(query_parameters)
+
+ try:
+ if body_parameters:
+ body_parameters_json = json.dumps(body_parameters)
+ response = self.session.request.open(method=self.method, url=url, data=body_parameters_json)
+ else:
+ response = self.session.request.open(method=self.method, url=url)
+ request_error = False
+ except HTTPError as e:
+ # An HTTPError has the same methods available as a valid response from request.open
+ response = e
+ request_error = True
+
+ # Return the result if JSON and success ({} for empty responses)
+ # Raise an exception if there was a failure.
+ try:
+ result_code = response.getcode()
+ result = json.loads(response.read())
+ except ValueError:
+ result = {}
+
+ if result or result == {}:
+ if result_code and result_code < 400:
+ return result
+ else:
+ raise RestOperationException(result)
+
+ # Raise a generic RestOperationException if this fails
+ raise RestOperationException({"status": result_code, "errors": [{"message": "REST Operation Failed"}]})
+
+
+class Resource(object):
+ """ Implement basic CRUD operations against a path. """
+
+ def __init__(self, session):
+ self.session = session
+ self.parameters = {}
+
+ for url in session._spec.get("paths").keys():
+ methods = session._spec.get("paths").get(url)
+ for method in methods.keys():
+ operation_spec = methods.get(method)
+ operation_name = operation_spec.get("operationId", None)
+ parameters = operation_spec.get("parameters")
+
+ if not operation_name:
+ if method.lower() == "post":
+ operation_name = "Create"
+ elif method.lower() == "get":
+ operation_name = "Get"
+ elif method.lower() == "put":
+ operation_name = "Update"
+ elif method.lower() == "delete":
+ operation_name = "Delete"
+ elif method.lower() == "patch":
+ operation_name = "Patch"
+ else:
+ raise SessionConfigurationException(to_native("Invalid REST method type {0}".format(method)))
+
+ # Get the non-parameter parts of the URL and append to the operation name
+ # e.g /application/version -> GetApplicationVersion
+ # e.g. /application/{id} -> GetApplication
+ # This may lead to duplicates, which we must prevent.
+ operation_name += re.sub(r"{(.*)}", "", url).replace("/", " ").title().replace(" ", "")
+ operation_spec["operationId"] = operation_name
+
+ op = RestOperation(session, url, method, parameters)
+ setattr(self, operation_name, bind(self, op.restmethod, operation_spec))
+
+
+# Session to encapsulate the connection parameters of the module_utils Request object, the api spec, etc
+class ECSSession(object):
+ def __init__(self, name, **kwargs):
+ """
+ Initialize our session
+ """
+
+ self._set_config(name, **kwargs)
+
+ def client(self):
+ resource = Resource(self)
+ return resource
+
+ def _set_config(self, name, **kwargs):
+ headers = {
+ "Content-Type": "application/json",
+ "Connection": "keep-alive",
+ }
+ self.request = Request(headers=headers, timeout=60)
+
+ configurators = [self._read_config_vars]
+ for configurator in configurators:
+ self._config = configurator(name, **kwargs)
+ if self._config:
+ break
+ if self._config is None:
+ raise SessionConfigurationException(to_native("No Configuration Found."))
+
+ # set up auth if passed
+ entrust_api_user = self.get_config("entrust_api_user")
+ entrust_api_key = self.get_config("entrust_api_key")
+ if entrust_api_user and entrust_api_key:
+ self.request.url_username = entrust_api_user
+ self.request.url_password = entrust_api_key
+ else:
+ raise SessionConfigurationException(to_native("User and key must be provided."))
+
+ # set up client certificate if passed (support all-in one or cert + key)
+ entrust_api_cert = self.get_config("entrust_api_cert")
+ entrust_api_cert_key = self.get_config("entrust_api_cert_key")
+ if entrust_api_cert:
+ self.request.client_cert = entrust_api_cert
+ if entrust_api_cert_key:
+ self.request.client_key = entrust_api_cert_key
+ else:
+ raise SessionConfigurationException(to_native("Client certificate for authentication to the API must be provided."))
+
+ # set up the spec
+ entrust_api_specification_path = self.get_config("entrust_api_specification_path")
+
+ if not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path):
+ raise SessionConfigurationException(to_native("OpenAPI specification was not found at location {0}.".format(entrust_api_specification_path)))
+ if not valid_file_format.match(entrust_api_specification_path):
+ raise SessionConfigurationException(to_native("OpenAPI specification filename must end in .json, .yml or .yaml"))
+
+ self.verify = True
+
+ if entrust_api_specification_path.startswith("http"):
+ try:
+ http_response = Request().open(method="GET", url=entrust_api_specification_path)
+ http_response_contents = http_response.read()
+ if entrust_api_specification_path.endswith(".json"):
+ self._spec = json.loads(http_response_contents)
+ elif entrust_api_specification_path.endswith(".yml") or entrust_api_specification_path.endswith(".yaml"):
+ self._spec = yaml.safe_load(http_response_contents)
+ except HTTPError as e:
+ raise SessionConfigurationException(to_native("Error downloading specification from address '{0}', received error code '{1}'".format(
+ entrust_api_specification_path, e.getcode())))
+ else:
+ with open(entrust_api_specification_path) as f:
+ if ".json" in entrust_api_specification_path:
+ self._spec = json.load(f)
+ elif ".yml" in entrust_api_specification_path or ".yaml" in entrust_api_specification_path:
+ self._spec = yaml.safe_load(f)
+
+ def get_config(self, item):
+ return self._config.get(item, None)
+
+ def _read_config_vars(self, name, **kwargs):
+ """ Read configuration from variables passed to the module. """
+ config = {}
+
+ entrust_api_specification_path = kwargs.get("entrust_api_specification_path")
+ if not entrust_api_specification_path or (not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path)):
+ raise SessionConfigurationException(
+ to_native(
+ "Parameter provided for entrust_api_specification_path of value '{0}' was not a valid file path or HTTPS address.".format(
+ entrust_api_specification_path
+ )
+ )
+ )
+
+ for required_file in ["entrust_api_cert", "entrust_api_cert_key"]:
+ file_path = kwargs.get(required_file)
+ if not file_path or not os.path.isfile(file_path):
+ raise SessionConfigurationException(
+ to_native("Parameter provided for {0} of value '{1}' was not a valid file path.".format(required_file, file_path))
+ )
+
+ for required_var in ["entrust_api_user", "entrust_api_key"]:
+ if not kwargs.get(required_var):
+ raise SessionConfigurationException(to_native("Parameter provided for {0} was missing.".format(required_var)))
+
+ config["entrust_api_cert"] = kwargs.get("entrust_api_cert")
+ config["entrust_api_cert_key"] = kwargs.get("entrust_api_cert_key")
+ config["entrust_api_specification_path"] = kwargs.get("entrust_api_specification_path")
+ config["entrust_api_user"] = kwargs.get("entrust_api_user")
+ config["entrust_api_key"] = kwargs.get("entrust_api_key")
+
+ return config
+
+
+def ECSClient(entrust_api_user=None, entrust_api_key=None, entrust_api_cert=None, entrust_api_cert_key=None, entrust_api_specification_path=None):
+ """Create an ECS client"""
+
+ if not YAML_FOUND:
+ raise SessionConfigurationException(missing_required_lib("PyYAML"), exception=YAML_IMP_ERR)
+
+ if entrust_api_specification_path is None:
+ entrust_api_specification_path = "https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml"
+
+ # Not functionally necessary with current uses of this module_util, but better to be explicit for future use cases
+ entrust_api_user = to_text(entrust_api_user)
+ entrust_api_key = to_text(entrust_api_key)
+ entrust_api_cert_key = to_text(entrust_api_cert_key)
+ entrust_api_specification_path = to_text(entrust_api_specification_path)
+
+ return ECSSession(
+ "ecs",
+ entrust_api_user=entrust_api_user,
+ entrust_api_key=entrust_api_key,
+ entrust_api_cert=entrust_api_cert,
+ entrust_api_cert_key=entrust_api_cert_key,
+ entrust_api_specification_path=entrust_api_specification_path,
+ ).client()
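+
+
+# Illustrative sketch (comments only); the credentials and file paths below are
+# placeholders, not working values:
+#
+#   client = ECSClient(
+#       entrust_api_user='apiuser',
+#       entrust_api_key='apikey',
+#       entrust_api_cert='/path/to/client.crt',
+#       entrust_api_cert_key='/path/to/client.key',
+#   )
+#   # Operation methods are bound dynamically from the spec's operationId values.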
diff --git a/test/support/integration/plugins/module_utils/mysql.py b/test/support/integration/plugins/module_utils/mysql.py
new file mode 100644
index 00000000..46198f36
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/mysql.py
@@ -0,0 +1,106 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# Copyright (c), Jonathan Mainguy <jon@soh.re>, 2015
+# Most of this was originally added by Sven Schliesing @muffl0n in the mysql_user.py module
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+
+try:
+ import pymysql as mysql_driver
+ _mysql_cursor_param = 'cursor'
+except ImportError:
+ try:
+ import MySQLdb as mysql_driver
+ import MySQLdb.cursors
+ _mysql_cursor_param = 'cursorclass'
+ except ImportError:
+ mysql_driver = None
+
+mysql_driver_fail_msg = 'The PyMySQL (Python 2.7 and Python 3.X) or MySQL-python (Python 2.X) module is required.'
+
+
+def mysql_connect(module, login_user=None, login_password=None, config_file='', ssl_cert=None, ssl_key=None, ssl_ca=None, db=None, cursor_class=None,
+ connect_timeout=30, autocommit=False):
+ config = {}
+
+ if ssl_ca is not None or ssl_key is not None or ssl_cert is not None:
+ config['ssl'] = {}
+
+ if module.params['login_unix_socket']:
+ config['unix_socket'] = module.params['login_unix_socket']
+ else:
+ config['host'] = module.params['login_host']
+ config['port'] = module.params['login_port']
+
+ if os.path.exists(config_file):
+ config['read_default_file'] = config_file
+
+ # If login_user or login_password are given, they should override the
+ # config file
+ if login_user is not None:
+ config['user'] = login_user
+ if login_password is not None:
+ config['passwd'] = login_password
+ if ssl_cert is not None:
+ config['ssl']['cert'] = ssl_cert
+ if ssl_key is not None:
+ config['ssl']['key'] = ssl_key
+ if ssl_ca is not None:
+ config['ssl']['ca'] = ssl_ca
+ if db is not None:
+ config['db'] = db
+ if connect_timeout is not None:
+ config['connect_timeout'] = connect_timeout
+
+ if _mysql_cursor_param == 'cursor':
+ # In case of PyMySQL driver:
+ db_connection = mysql_driver.connect(autocommit=autocommit, **config)
+ else:
+ # In case of MySQLdb driver
+ db_connection = mysql_driver.connect(**config)
+ if autocommit:
+ db_connection.autocommit(True)
+
+ if cursor_class == 'DictCursor':
+ return db_connection.cursor(**{_mysql_cursor_param: mysql_driver.cursors.DictCursor}), db_connection
+ else:
+ return db_connection.cursor(), db_connection
+
+
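+# Illustrative sketch (comments only): 'module' is a hypothetical AnsibleModule
+# built from mysql_common_argument_spec(), so login_unix_socket, login_host and
+# login_port are available in module.params; credentials are placeholders:
+#
+#   cursor, db_conn = mysql_connect(module, login_user='root',
+#                                   login_password='secret', db='mysql',
+#                                   cursor_class='DictCursor')
+
+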
+def mysql_common_argument_spec():
+ return dict(
+ login_user=dict(type='str', default=None),
+ login_password=dict(type='str', no_log=True),
+ login_host=dict(type='str', default='localhost'),
+ login_port=dict(type='int', default=3306),
+ login_unix_socket=dict(type='str'),
+ config_file=dict(type='path', default='~/.my.cnf'),
+ connect_timeout=dict(type='int', default=30),
+ client_cert=dict(type='path', aliases=['ssl_cert']),
+ client_key=dict(type='path', aliases=['ssl_key']),
+ ca_cert=dict(type='path', aliases=['ssl_ca']),
+ )
diff --git a/test/support/integration/plugins/module_utils/net_tools/__init__.py b/test/support/integration/plugins/module_utils/net_tools/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/net_tools/__init__.py
diff --git a/test/support/integration/plugins/module_utils/network/__init__.py b/test/support/integration/plugins/module_utils/network/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/network/__init__.py
diff --git a/test/support/integration/plugins/module_utils/network/common/__init__.py b/test/support/integration/plugins/module_utils/network/common/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/network/common/__init__.py
diff --git a/test/support/integration/plugins/module_utils/network/common/utils.py b/test/support/integration/plugins/module_utils/network/common/utils.py
new file mode 100644
index 00000000..80317387
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/network/common/utils.py
@@ -0,0 +1,643 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# (c) 2016 Red Hat Inc.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+# Networking tools for network modules only
+
+import re
+import ast
+import operator
+import socket
+import json
+
+from itertools import chain
+
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.module_utils.common._collections_compat import Mapping
+from ansible.module_utils.six import iteritems, string_types
+from ansible.module_utils import basic
+from ansible.module_utils.parsing.convert_bool import boolean
+
+# Backwards compatibility for 3rd party modules
+# TODO(pabelanger): With move to ansible.netcommon, we should clean this code
+ # up and have modules import it directly themselves.
+from ansible.module_utils.common.network import ( # noqa: F401
+ to_bits, is_netmask, is_masklen, to_netmask, to_masklen, to_subnet, to_ipv6_network, VALID_MASKS
+)
+
+try:
+ from jinja2 import Environment, StrictUndefined
+ from jinja2.exceptions import UndefinedError
+ HAS_JINJA2 = True
+except ImportError:
+ HAS_JINJA2 = False
+
+
+OPERATORS = frozenset(['ge', 'gt', 'eq', 'neq', 'lt', 'le'])
+ALIASES = frozenset([('min', 'ge'), ('max', 'le'), ('exactly', 'eq'), ('neq', 'ne')])
+
+
+def to_list(val):
+ if isinstance(val, (list, tuple, set)):
+ return list(val)
+ elif val is not None:
+ return [val]
+ else:
+ return list()
+
+
+def to_lines(stdout):
+ for item in stdout:
+ if isinstance(item, string_types):
+ item = to_text(item).split('\n')
+ yield item
+
+
+def transform_commands(module):
+ transform = ComplexList(dict(
+ command=dict(key=True),
+ output=dict(),
+ prompt=dict(type='list'),
+ answer=dict(type='list'),
+ newline=dict(type='bool', default=True),
+ sendonly=dict(type='bool', default=False),
+ check_all=dict(type='bool', default=False),
+ ), module)
+
+ return transform(module.params['commands'])
+
+
+def sort_list(val):
+ if isinstance(val, list):
+ return sorted(val)
+ return val
+
+
+class Entity(object):
+ """Transforms a dict to with an argument spec
+
+ This class will take a dict and apply an Ansible argument spec to the
+ values. The resulting dict will contain all of the keys in the param
+ with appropriate values set.
+
+ Example::
+
+ argument_spec = dict(
+ command=dict(key=True),
+ display=dict(default='text', choices=['text', 'json']),
+ validate=dict(type='bool')
+ )
+ transform = Entity(module, argument_spec)
+ value = dict(command='foo')
+ result = transform(value)
+ print result
+ {'command': 'foo', 'display': 'text', 'validate': None}
+
+ Supported argument spec:
+ * key - specifies how to map a single value to a dict
+ * read_from - read and apply the argument_spec from the module
+ * required - a value is required
+ * type - type of value (uses AnsibleModule type checker)
+ * fallback - implements fallback function
+ * choices - set of valid options
+ * default - default value
+ """
+
+ def __init__(self, module, attrs=None, args=None, keys=None, from_argspec=False):
+ args = [] if args is None else args
+
+ self._attributes = attrs or {}
+ self._module = module
+
+ for arg in args:
+ self._attributes[arg] = dict()
+ if from_argspec:
+ self._attributes[arg]['read_from'] = arg
+ if keys and arg in keys:
+ self._attributes[arg]['key'] = True
+
+ self.attr_names = frozenset(self._attributes.keys())
+
+ _has_key = False
+
+ for name, attr in iteritems(self._attributes):
+ if attr.get('read_from'):
+ if attr['read_from'] not in self._module.argument_spec:
+ module.fail_json(msg='argument %s does not exist' % attr['read_from'])
+ spec = self._module.argument_spec.get(attr['read_from'])
+ for key, value in iteritems(spec):
+ if key not in attr:
+ attr[key] = value
+
+ if attr.get('key'):
+ if _has_key:
+ module.fail_json(msg='only one key value can be specified')
+ _has_key = True
+ attr['required'] = True
+
+ def serialize(self):
+ return self._attributes
+
+ def to_dict(self, value):
+ obj = {}
+ for name, attr in iteritems(self._attributes):
+ if attr.get('key'):
+ obj[name] = value
+ else:
+ obj[name] = attr.get('default')
+ return obj
+
+ def __call__(self, value, strict=True):
+ if not isinstance(value, dict):
+ value = self.to_dict(value)
+
+ if strict:
+ unknown = set(value).difference(self.attr_names)
+ if unknown:
+ self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown))
+
+ for name, attr in iteritems(self._attributes):
+ if value.get(name) is None:
+ value[name] = attr.get('default')
+
+ if attr.get('fallback') and not value.get(name):
+ fallback = attr.get('fallback', (None,))
+ fallback_strategy = fallback[0]
+ fallback_args = []
+ fallback_kwargs = {}
+ if fallback_strategy is not None:
+ for item in fallback[1:]:
+ if isinstance(item, dict):
+ fallback_kwargs = item
+ else:
+ fallback_args = item
+ try:
+ value[name] = fallback_strategy(*fallback_args, **fallback_kwargs)
+ except basic.AnsibleFallbackNotFound:
+ continue
+
+ if attr.get('required') and value.get(name) is None:
+ self._module.fail_json(msg='missing required attribute %s' % name)
+
+ if 'choices' in attr:
+ if value[name] not in attr['choices']:
+ self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name]))
+
+ if value[name] is not None:
+ value_type = attr.get('type', 'str')
+ type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type]
+ type_checker(value[name])
+ elif value.get(name):
+ value[name] = self._module.params[name]
+
+ return value
+
+
+class EntityCollection(Entity):
+ """Extends ```Entity``` to handle a list of dicts """
+
+ def __call__(self, iterable, strict=True):
+ if iterable is None:
+ iterable = [super(EntityCollection, self).__call__(self._module.params, strict)]
+
+ if not isinstance(iterable, (list, tuple)):
+ self._module.fail_json(msg='value must be an iterable')
+
+ return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable]
+
+
+# these two are for backwards compatibility and can be removed once all of the
+# modules that use them are updated
+class ComplexDict(Entity):
+ def __init__(self, attrs, module, *args, **kwargs):
+ super(ComplexDict, self).__init__(module, attrs, *args, **kwargs)
+
+
+class ComplexList(EntityCollection):
+ def __init__(self, attrs, module, *args, **kwargs):
+ super(ComplexList, self).__init__(module, attrs, *args, **kwargs)
+
+
+def dict_diff(base, comparable):
+ """ Generate a dict object of differences
+
+ This function will compare two dict objects and return the difference
+ between them as a dict object. For scalar values, the key will reflect
+ the updated value. If the key does not exist in `comparable`, then no
+ key will be returned. For lists, the value in comparable will wholly replace
+ the value in base for the key. For dicts, the returned value will only
+ return keys that are different.
+
+ :param base: dict object to base the diff on
+ :param comparable: dict object to compare against base
+
+ :returns: new dict object with differences
+ """
+ if not isinstance(base, dict):
+ raise AssertionError("`base` must be of type <dict>")
+ if not isinstance(comparable, dict):
+ if comparable is None:
+ comparable = dict()
+ else:
+ raise AssertionError("`comparable` must be of type <dict>")
+
+ updates = dict()
+
+ for key, value in iteritems(base):
+ if isinstance(value, dict):
+ item = comparable.get(key)
+ if item is not None:
+ sub_diff = dict_diff(value, comparable[key])
+ if sub_diff:
+ updates[key] = sub_diff
+ else:
+ comparable_value = comparable.get(key)
+ if comparable_value is not None:
+ if sort_list(base[key]) != sort_list(comparable_value):
+ updates[key] = comparable_value
+
+ for key in set(comparable.keys()).difference(base.keys()):
+ updates[key] = comparable.get(key)
+
+ return updates
+
+
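+# Illustrative sketch (comments only): scalars and lists are replaced
+# wholesale, nested dicts are diffed recursively:
+#
+#   >>> dict_diff({'a': 1, 'b': {'c': 2}}, {'a': 1, 'b': {'c': 3}, 'd': 4})
+#   {'b': {'c': 3}, 'd': 4}
+
+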
+def dict_merge(base, other):
+ """ Return a new dict object that combines base and other
+
+ This will create a new dict object that is a combination of the key/value
+ pairs from base and other. When both keys exist, the value will be
+ selected from other. If the value is a list object, the two lists will
+ be combined and duplicate entries removed.
+
+ :param base: dict object to serve as base
+ :param other: dict object to combine with base
+
+ :returns: new combined dict object
+ """
+ if not isinstance(base, dict):
+ raise AssertionError("`base` must be of type <dict>")
+ if not isinstance(other, dict):
+ raise AssertionError("`other` must be of type <dict>")
+
+ combined = dict()
+
+ for key, value in iteritems(base):
+ if isinstance(value, dict):
+ if key in other:
+ item = other.get(key)
+ if item is not None:
+ if isinstance(other[key], Mapping):
+ combined[key] = dict_merge(value, other[key])
+ else:
+ combined[key] = other[key]
+ else:
+ combined[key] = item
+ else:
+ combined[key] = value
+ elif isinstance(value, list):
+ if key in other:
+ item = other.get(key)
+ if item is not None:
+ try:
+ combined[key] = list(set(chain(value, item)))
+ except TypeError:
+ value.extend([i for i in item if i not in value])
+ combined[key] = value
+ else:
+ combined[key] = item
+ else:
+ combined[key] = value
+ else:
+ if key in other:
+ other_value = other.get(key)
+ if other_value is not None:
+ if sort_list(base[key]) != sort_list(other_value):
+ combined[key] = other_value
+ else:
+ combined[key] = value
+ else:
+ combined[key] = other_value
+ else:
+ combined[key] = value
+
+ for key in set(other.keys()).difference(base.keys()):
+ combined[key] = other.get(key)
+
+ return combined
+
+
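+# Illustrative sketch (comments only): values from 'other' win, and list values
+# are unioned with duplicates removed (result order is not guaranteed):
+#
+#   >>> dict_merge({'a': [1, 2], 'b': 1}, {'a': [2, 3], 'c': 3})
+#   {'a': [1, 2, 3], 'b': 1, 'c': 3}
+
+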
+def param_list_to_dict(param_list, unique_key="name", remove_key=True):
+ """Rotates a list of dictionaries to be a dictionary of dictionaries.
+
+ :param param_list: The aforementioned list of dictionaries
+ :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value
+ behind this key will be the key each dictionary can be found at in the new root dictionary
+ :param remove_key: If True, remove unique_key from the individual dictionaries before returning.
+ """
+ param_dict = {}
+ for params in param_list:
+ params = params.copy()
+ if remove_key:
+ name = params.pop(unique_key)
+ else:
+ name = params.get(unique_key)
+ param_dict[name] = params
+
+ return param_dict
+
+
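+# Illustrative sketch (comments only):
+#
+#   >>> param_list_to_dict([{'name': 'eth0', 'mtu': 1500},
+#   ...                     {'name': 'eth1', 'mtu': 9000}])
+#   {'eth0': {'mtu': 1500}, 'eth1': {'mtu': 9000}}
+
+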
+def conditional(expr, val, cast=None):
+ match = re.match(r'^(.+)\((.+)\)$', str(expr), re.I)
+ if match:
+ op, arg = match.groups()
+ else:
+ op = 'eq'
+ if ' ' in str(expr):
+ raise AssertionError('invalid expression: cannot contain spaces')
+ arg = expr
+
+ if cast is None and val is not None:
+ arg = type(val)(arg)
+ elif callable(cast):
+ arg = cast(arg)
+ val = cast(val)
+
+ op = next((oper for alias, oper in ALIASES if op == alias), op)
+
+ if not hasattr(operator, op) and op not in OPERATORS:
+ raise ValueError('unknown operator: %s' % op)
+
+ func = getattr(operator, op)
+ return func(val, arg)
+
+
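+# Illustrative sketch (comments only): the expression is either 'op(arg)' using
+# an operator/alias name, or a bare value implying 'eq'; 'arg' is cast to the
+# type of 'val' unless an explicit cast is supplied:
+#
+#   >>> conditional('ge(2)', 3)
+#   True
+#   >>> conditional('5', 5)
+#   True
+
+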
+def ternary(value, true_val, false_val):
+ ''' value ? true_val : false_val '''
+ if value:
+ return true_val
+ else:
+ return false_val
+
+
+def remove_default_spec(spec):
+ for item in spec:
+ if 'default' in spec[item]:
+ del spec[item]['default']
+
+
+def validate_ip_address(address):
+ try:
+ socket.inet_aton(address)
+ except socket.error:
+ return False
+ return address.count('.') == 3
+
+
+def validate_ip_v6_address(address):
+ try:
+ socket.inet_pton(socket.AF_INET6, address)
+ except socket.error:
+ return False
+ return True
+
+
+def validate_prefix(prefix):
+ if prefix and not 0 <= int(prefix) <= 32:
+ return False
+ return True
+
+
+def load_provider(spec, args):
+ provider = args.get('provider') or {}
+ for key, value in iteritems(spec):
+ if key not in provider:
+ if 'fallback' in value:
+ provider[key] = _fallback(value['fallback'])
+ elif 'default' in value:
+ provider[key] = value['default']
+ else:
+ provider[key] = None
+ if 'authorize' in provider:
+ # Coerce authorize to a boolean if a string has somehow snuck in.
+ provider['authorize'] = boolean(provider['authorize'] or False)
+ args['provider'] = provider
+ return provider
+
+
+def _fallback(fallback):
+ strategy = fallback[0]
+ args = []
+ kwargs = {}
+
+ for item in fallback[1:]:
+ if isinstance(item, dict):
+ kwargs = item
+ else:
+ args = item
+ try:
+ return strategy(*args, **kwargs)
+ except basic.AnsibleFallbackNotFound:
+ pass
+
+
+def generate_dict(spec):
+ """
+ Generate dictionary which is in sync with argspec
+
+ :param spec: A dictionary that is the argspec of the module
+ :rtype: A dictionary
+ :returns: A dictionary in sync with argspec with default value
+ """
+ obj = {}
+ if not spec:
+ return obj
+
+ for key, val in iteritems(spec):
+ if 'default' in val:
+ dct = {key: val['default']}
+ elif 'type' in val and val['type'] == 'dict':
+ dct = {key: generate_dict(val['options'])}
+ else:
+ dct = {key: None}
+ obj.update(dct)
+ return obj
+
+
+def parse_conf_arg(cfg, arg):
+ """
+ Parse config based on argument
+
+ :param cfg: A text string which is a line of configuration.
+ :param arg: A text string which is to be matched.
+ :rtype: A text string
+ :returns: A text string if match is found
+ """
+ match = re.search(r'%s (.+)(\n|$)' % arg, cfg, re.M)
+ if match:
+ result = match.group(1).strip()
+ else:
+ result = None
+ return result
+
+
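+# Illustrative sketch (comments only):
+#
+#   >>> parse_conf_arg('interface Ethernet1\n description uplink\n', 'description')
+#   'uplink'
+
+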
+def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str='no'):
+ """
+ Parse config based on command
+
+ :param cfg: A text string which is a line of configuration.
+ :param cmd: A text string which is the command to be matched
+ :param res1: A text string to be returned if the command is present
+ :param res2: A text string to be returned if the negate command
+ is present
+ :param delete_str: A text string to identify the start of the
+ negate command
+ :rtype: A text string
+ :returns: A text string if match is found
+ """
+ match = re.search(r'\n\s+%s(\n|$)' % cmd, cfg)
+ if match:
+ return res1
+ if res2 is not None:
+ match = re.search(r'\n\s+%s %s(\n|$)' % (delete_str, cmd), cfg)
+ if match:
+ return res2
+ return None
+
+
+def get_xml_conf_arg(cfg, path, data='text'):
+ """
+ :param cfg: The top level configuration lxml Element tree object
+ :param path: The relative xpath w.r.t to top level element (cfg)
+ to be searched in the xml hierarchy
+ :param data: The type of data to be returned for the matched xml node.
+ Valid values are text, tag, attrib, with default as text.
+ :return: Returns the required type for the matched xml node or else None
+ """
+ match = cfg.xpath(path)
+ if len(match):
+ if data == 'tag':
+ result = getattr(match[0], 'tag')
+ elif data == 'attrib':
+ result = getattr(match[0], 'attrib')
+ else:
+ result = getattr(match[0], 'text')
+ else:
+ result = None
+ return result
+
+
+def remove_empties(cfg_dict):
+ """
+ Generate final config dictionary
+
+ :param cfg_dict: A dictionary parsed in the facts system
+ :rtype: A dictionary
+ :returns: A dictionary by eliminating keys that have null values
+ """
+ final_cfg = {}
+ if not cfg_dict:
+ return final_cfg
+
+ for key, val in iteritems(cfg_dict):
+ dct = None
+ if isinstance(val, dict):
+ child_val = remove_empties(val)
+ if child_val:
+ dct = {key: child_val}
+ elif (isinstance(val, list) and val
+ and all([isinstance(x, dict) for x in val])):
+ child_val = [remove_empties(x) for x in val]
+ if child_val:
+ dct = {key: child_val}
+ elif val not in [None, [], {}, (), '']:
+ dct = {key: val}
+ if dct:
+ final_cfg.update(dct)
+ return final_cfg
+
+
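+# Illustrative sketch (comments only): None, empty strings and empty containers
+# are dropped at every level:
+#
+#   >>> remove_empties({'a': 1, 'b': None, 'c': [], 'd': {'e': ''}})
+#   {'a': 1}
+
+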
+def validate_config(spec, data):
+ """
+ Validate the input data against the AnsibleModule spec format
+ :param spec: Ansible argument spec
+ :param data: Data to be validated
+ :return: validated and type-coerced parameters as a dictionary
+ """
+ params = basic._ANSIBLE_ARGS
+ basic._ANSIBLE_ARGS = to_bytes(json.dumps({'ANSIBLE_MODULE_ARGS': data}))
+ validated_data = basic.AnsibleModule(spec).params
+ basic._ANSIBLE_ARGS = params
+ return validated_data
+
+
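+# Illustrative sketch (comments only): the data is round-tripped through a
+# temporary AnsibleModule, so defaults and type coercion from the spec apply
+# (dict key ordering may vary):
+#
+#   >>> spec = dict(name=dict(type='str'), enabled=dict(type='bool', default=False))
+#   >>> validate_config(spec, {'name': 'eth0'})
+#   {'name': 'eth0', 'enabled': False}
+
+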
+def search_obj_in_list(name, lst, key='name'):
+ if not lst:
+ return None
+ else:
+ for item in lst:
+ if item.get(key) == name:
+ return item
+
+
+class Template:
+
+ def __init__(self):
+ if not HAS_JINJA2:
+ raise ImportError("jinja2 is required but does not appear to be installed. "
+ "It can be installed using `pip install jinja2`")
+
+ self.env = Environment(undefined=StrictUndefined)
+ self.env.filters.update({'ternary': ternary})
+
+ def __call__(self, value, variables=None, fail_on_undefined=True):
+ variables = variables or {}
+
+ if not self.contains_vars(value):
+ return value
+
+ try:
+ value = self.env.from_string(value).render(variables)
+ except UndefinedError:
+ if not fail_on_undefined:
+ return None
+ raise
+
+ if value:
+ try:
+ return ast.literal_eval(value)
+ except Exception:
+ return str(value)
+ else:
+ return None
+
+ def contains_vars(self, data):
+ if isinstance(data, string_types):
+ for marker in (self.env.block_start_string, self.env.variable_start_string, self.env.comment_start_string):
+ if marker in data:
+ return True
+ return False
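+
+
+# Illustrative sketch (comments only; requires jinja2): rendered output is
+# passed through ast.literal_eval, so numeric and container literals come back
+# as Python objects, while plain text is returned unchanged:
+#
+#   >>> Template()('{{ a + 1 }}', {'a': 1})
+#   2
+#   >>> Template()('no vars here')
+#   'no vars here'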
diff --git a/test/support/integration/plugins/module_utils/postgres.py b/test/support/integration/plugins/module_utils/postgres.py
new file mode 100644
index 00000000..63811c30
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/postgres.py
@@ -0,0 +1,330 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# Copyright (c), Ted Timmons <ted@timmons.me>, 2017.
+# Most of this was originally added by other creators in the postgresql_user module.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+psycopg2 = None # This line is needed for unit tests
+try:
+ import psycopg2
+ HAS_PSYCOPG2 = True
+except ImportError:
+ HAS_PSYCOPG2 = False
+
+from ansible.module_utils.basic import missing_required_lib
+from ansible.module_utils._text import to_native
+from ansible.module_utils.six import iteritems
+from distutils.version import LooseVersion
+
+
+def postgres_common_argument_spec():
+ """
+ Return a dictionary with connection options.
+
+ These options are commonly used by most PostgreSQL modules.
+ """
+ return dict(
+ login_user=dict(default='postgres'),
+ login_password=dict(default='', no_log=True),
+ login_host=dict(default=''),
+ login_unix_socket=dict(default=''),
+ port=dict(type='int', default=5432, aliases=['login_port']),
+ ssl_mode=dict(default='prefer', choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']),
+ ca_cert=dict(aliases=['ssl_rootcert']),
+ )
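+
+# Usage sketch (illustrative): a PostgreSQL module typically extends this
+# common spec with its own options before constructing AnsibleModule (the
+# extra options below are hypothetical; assumes AnsibleModule is imported
+# from ansible.module_utils.basic).
+#
+# argument_spec = postgres_common_argument_spec()
+# argument_spec.update(
+# db=dict(type='str'),
+# session_role=dict(type='str'),
+# )
+# module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=True)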
+
+
+def ensure_required_libs(module):
+ """Check required libraries."""
+ if not HAS_PSYCOPG2:
+ module.fail_json(msg=missing_required_lib('psycopg2'))
+
+ if module.params.get('ca_cert') and LooseVersion(psycopg2.__version__) < LooseVersion('2.4.3'):
+ module.fail_json(msg='psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter')
+
+
+def connect_to_db(module, conn_params, autocommit=False, fail_on_conn=True):
+ """Connect to a PostgreSQL database.
+
+ Return psycopg2 connection object.
+
+ Args:
+ module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
+ conn_params (dict) -- dictionary with connection parameters
+
+ Kwargs:
+ autocommit (bool) -- commit automatically (default False)
+ fail_on_conn (bool) -- fail if connection failed or just warn and return None (default True)
+ """
+ ensure_required_libs(module)
+
+ db_connection = None
+ try:
+ db_connection = psycopg2.connect(**conn_params)
+ if autocommit:
+ if LooseVersion(psycopg2.__version__) >= LooseVersion('2.4.2'):
+ db_connection.set_session(autocommit=True)
+ else:
+ db_connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+
+ # Switch role, if specified:
+ if module.params.get('session_role'):
+ cursor = db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
+
+ try:
+ cursor.execute('SET ROLE "%s"' % module.params['session_role'])
+ except Exception as e:
+ module.fail_json(msg="Could not switch role: %s" % to_native(e))
+ finally:
+ cursor.close()
+
+ except TypeError as e:
+ if 'sslrootcert' in e.args[0]:
+ module.fail_json(msg='Postgresql server must be at least '
+ 'version 8.4 to support sslrootcert')
+
+ if fail_on_conn:
+ module.fail_json(msg="unable to connect to database: %s" % to_native(e))
+ else:
+ module.warn("PostgreSQL server is unavailable: %s" % to_native(e))
+ db_connection = None
+
+ except Exception as e:
+ if fail_on_conn:
+ module.fail_json(msg="unable to connect to database: %s" % to_native(e))
+ else:
+ module.warn("PostgreSQL server is unavailable: %s" % to_native(e))
+ db_connection = None
+
+ return db_connection
+
+
+def exec_sql(obj, query, query_params=None, ddl=False, add_to_executed=True, dont_exec=False):
+ """Execute SQL.
+
+ Auxiliary function for PostgreSQL user classes.
+
+ Returns a query result if possible, or True/False if the ddl=True arg was passed.
+ This is necessary for statements that don't return any result (like DDL queries).
+
+ Args:
+ obj (obj) -- must be an object of a user class.
+ The object must have module (AnsibleModule class object) and
+ cursor (psycopg cursor object) attributes
+ query (str) -- SQL query to execute
+
+ Kwargs:
+ query_params (dict or tuple) -- Query parameters to prevent SQL injections,
+ could be a dict or tuple
+ ddl (bool) -- must return True or False instead of rows (typical for DDL queries)
+ (default False)
+ add_to_executed (bool) -- append the query to obj.executed_queries attribute
+ dont_exec (bool) -- used with add_to_executed=True to generate a query, add it
+ to obj.executed_queries list and return True (default False)
+ """
+
+ if dont_exec:
+ # This is usually needed to return queries in check_mode
+ # without execution
+ query = obj.cursor.mogrify(query, query_params)
+ if add_to_executed:
+ obj.executed_queries.append(query)
+
+ return True
+
+ try:
+ if query_params is not None:
+ obj.cursor.execute(query, query_params)
+ else:
+ obj.cursor.execute(query)
+
+ if add_to_executed:
+ if query_params is not None:
+ obj.executed_queries.append(obj.cursor.mogrify(query, query_params))
+ else:
+ obj.executed_queries.append(query)
+
+ if not ddl:
+ res = obj.cursor.fetchall()
+ return res
+ return True
+ except Exception as e:
+ obj.module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e)))
+ return False
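+
+# Minimal sketch (illustrative) of the "user class" contract exec_sql expects:
+# any object carrying module, cursor and executed_queries attributes will do.
+# The PgRole class below is hypothetical.
+#
+# class PgRole(object):
+# def __init__(self, module, cursor):
+# self.module = module
+# self.cursor = cursor
+# self.executed_queries = []
+#
+# role = PgRole(module, cursor)
+# rows = exec_sql(role, "SELECT rolname FROM pg_roles WHERE rolname = %(name)s",
+# query_params={'name': 'app_user'})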
+
+
+def get_conn_params(module, params_dict, warn_db_default=True):
+ """Get connection parameters from the passed dictionary.
+
+ Return a dictionary with parameters to connect to PostgreSQL server.
+
+ Args:
+ module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
+ params_dict (dict) -- dictionary with variables
+
+ Kwargs:
+ warn_db_default (bool) -- warn that the default DB is used (default True)
+ """
+ # To use default values, keyword arguments must be absent, so
+ # check which values are empty and don't include them in the return dictionary
+ params_map = {
+ "login_host": "host",
+ "login_user": "user",
+ "login_password": "password",
+ "port": "port",
+ "ssl_mode": "sslmode",
+ "ca_cert": "sslrootcert"
+ }
+
+ # Might be different in the modules:
+ if params_dict.get('db'):
+ params_map['db'] = 'database'
+ elif params_dict.get('database'):
+ params_map['database'] = 'database'
+ elif params_dict.get('login_db'):
+ params_map['login_db'] = 'database'
+ else:
+ if warn_db_default:
+ module.warn('Database name has not been passed, '
+ 'connecting to the default database.')
+
+ kw = dict((params_map[k], v) for (k, v) in iteritems(params_dict)
+ if k in params_map and v != '' and v is not None)
+
+ # If a login_unix_socket is specified, incorporate it here.
+ is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost"
+ if is_localhost and params_dict["login_unix_socket"] != "":
+ kw["host"] = params_dict["login_unix_socket"]
+
+ return kw
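+
+# Typical connection flow (illustrative): build the libpq keyword arguments
+# from the module params, then open the connection and a cursor. Assumes the
+# module spec includes the common options above plus 'db'/'session_role'.
+#
+# from psycopg2.extras import DictCursor
+# conn_params = get_conn_params(module, module.params, warn_db_default=False)
+# db_connection = connect_to_db(module, conn_params, autocommit=True)
+# cursor = db_connection.cursor(cursor_factory=DictCursor)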
+
+
+class PgMembership(object):
+ def __init__(self, module, cursor, groups, target_roles, fail_on_role=True):
+ self.module = module
+ self.cursor = cursor
+ self.target_roles = [r.strip() for r in target_roles]
+ self.groups = [r.strip() for r in groups]
+ self.executed_queries = []
+ self.granted = {}
+ self.revoked = {}
+ self.fail_on_role = fail_on_role
+ self.non_existent_roles = []
+ self.changed = False
+ self.__check_roles_exist()
+
+ def grant(self):
+ for group in self.groups:
+ self.granted[group] = []
+
+ for role in self.target_roles:
+ # If the role is already in the group, skip it:
+ if self.__check_membership(group, role):
+ continue
+
+ query = 'GRANT "%s" TO "%s"' % (group, role)
+ self.changed = exec_sql(self, query, ddl=True)
+
+ if self.changed:
+ self.granted[group].append(role)
+
+ return self.changed
+
+ def revoke(self):
+ for group in self.groups:
+ self.revoked[group] = []
+
+ for role in self.target_roles:
+ # If the role is not in the group, skip it:
+ if not self.__check_membership(group, role):
+ continue
+
+ query = 'REVOKE "%s" FROM "%s"' % (group, role)
+ self.changed = exec_sql(self, query, ddl=True)
+
+ if self.changed:
+ self.revoked[group].append(role)
+
+ return self.changed
+
+ def __check_membership(self, src_role, dst_role):
+ query = ("SELECT ARRAY(SELECT b.rolname FROM "
+ "pg_catalog.pg_auth_members m "
+ "JOIN pg_catalog.pg_roles b ON (m.roleid = b.oid) "
+ "WHERE m.member = r.oid) "
+ "FROM pg_catalog.pg_roles r "
+ "WHERE r.rolname = %(dst_role)s")
+
+ res = exec_sql(self, query, query_params={'dst_role': dst_role}, add_to_executed=False)
+ membership = []
+ if res:
+ membership = res[0][0]
+
+ if not membership:
+ return False
+
+ if src_role in membership:
+ return True
+
+ return False
+
+ def __check_roles_exist(self):
+ existent_groups = self.__roles_exist(self.groups)
+ existent_roles = self.__roles_exist(self.target_roles)
+
+ for group in self.groups:
+ if group not in existent_groups:
+ if self.fail_on_role:
+ self.module.fail_json(msg="Role %s does not exist" % group)
+ else:
+ self.module.warn("Role %s does not exist, pass" % group)
+ self.non_existent_roles.append(group)
+
+ for role in self.target_roles:
+ if role not in existent_roles:
+ if self.fail_on_role:
+ self.module.fail_json(msg="Role %s does not exist" % role)
+ else:
+ self.module.warn("Role %s does not exist, pass" % role)
+
+ if role not in self.groups:
+ self.non_existent_roles.append(role)
+
+ else:
+ if self.fail_on_role:
+ self.module.exit_json(msg="Role role '%s' is a member of role '%s'" % (role, role))
+ else:
+ self.module.warn("Role role '%s' is a member of role '%s', pass" % (role, role))
+
+ # Update role lists, excluding non-existent roles:
+ self.groups = [g for g in self.groups if g not in self.non_existent_roles]
+
+ self.target_roles = [r for r in self.target_roles if r not in self.non_existent_roles]
+
+ def __roles_exist(self, roles):
+ tmp = ["'" + x + "'" for x in roles]
+ query = "SELECT rolname FROM pg_roles WHERE rolname IN (%s)" % ','.join(tmp)
+ return [x[0] for x in exec_sql(self, query, add_to_executed=False)]
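+
+# Usage sketch (illustrative): PgMembership resolves existing membership at
+# construction time, so grant()/revoke() only issue GRANT/REVOKE statements
+# for roles whose membership actually needs to change.
+#
+# membership = PgMembership(module, cursor, groups=['readers'],
+# target_roles=['alice', 'bob'])
+# changed = membership.grant()
+# module.exit_json(changed=changed, granted=membership.granted,
+# queries=membership.executed_queries)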
diff --git a/test/support/integration/plugins/module_utils/rabbitmq.py b/test/support/integration/plugins/module_utils/rabbitmq.py
new file mode 100644
index 00000000..cf764006
--- /dev/null
+++ b/test/support/integration/plugins/module_utils/rabbitmq.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright: (c) 2016, Jorge Rodriguez <jorge.rodriguez@tiriel.eu>
+# Copyright: (c) 2018, John Imison <john+github@imison.net>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+from ansible.module_utils._text import to_native
+from ansible.module_utils.basic import missing_required_lib
+from ansible.module_utils.six.moves.urllib import parse as urllib_parse
+from mimetypes import MimeTypes
+
+import os
+import json
+import traceback
+
+PIKA_IMP_ERR = None
+try:
+ import pika
+ import pika.exceptions
+ from pika import spec
+ HAS_PIKA = True
+except ImportError:
+ PIKA_IMP_ERR = traceback.format_exc()
+ HAS_PIKA = False
+
+
+def rabbitmq_argument_spec():
+ return dict(
+ login_user=dict(type='str', default='guest'),
+ login_password=dict(type='str', default='guest', no_log=True),
+ login_host=dict(type='str', default='localhost'),
+ login_port=dict(type='str', default='15672'),
+ login_protocol=dict(type='str', default='http', choices=['http', 'https']),
+ ca_cert=dict(type='path', aliases=['cacert']),
+ client_cert=dict(type='path', aliases=['cert']),
+ client_key=dict(type='path', aliases=['key']),
+ vhost=dict(type='str', default='/'),
+ )
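+
+# Usage sketch (illustrative): management-API based RabbitMQ modules merge this
+# spec with their own options, in the same way as the postgres helper above
+# (assumes AnsibleModule is imported from ansible.module_utils.basic).
+#
+# argument_spec = rabbitmq_argument_spec()
+# argument_spec.update(user=dict(type='str', required=True))
+# module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=True)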
+
+
+# notification/rabbitmq_basic_publish.py
+class RabbitClient():
+ def __init__(self, module):
+ self.module = module
+ self.params = module.params
+ self.check_required_library()
+ self.check_host_params()
+ self.url = self.params['url']
+ self.proto = self.params['proto']
+ self.username = self.params['username']
+ self.password = self.params['password']
+ self.host = self.params['host']
+ self.port = self.params['port']
+ self.vhost = self.params['vhost']
+ self.queue = self.params['queue']
+ self.headers = self.params['headers']
+ self.cafile = self.params['cafile']
+ self.certfile = self.params['certfile']
+ self.keyfile = self.params['keyfile']
+
+ if self.host is not None:
+ self.build_url()
+
+ if self.cafile is not None:
+ self.append_ssl_certs()
+
+ self.connect_to_rabbitmq()
+
+ def check_required_library(self):
+ if not HAS_PIKA:
+ self.module.fail_json(msg=missing_required_lib("pika"), exception=PIKA_IMP_ERR)
+
+ def check_host_params(self):
+ # Fail if url is specified and other conflicting parameters have been specified
+ if self.params['url'] is not None and any(self.params[k] is not None for k in ['proto', 'host', 'port', 'password', 'username', 'vhost']):
+ self.module.fail_json(msg="url and proto, host, port, vhost, username or password cannot be specified at the same time.")
+
+ # Fail if url not specified and there is a missing parameter to build the url
+ if self.params['url'] is None and any(self.params[k] is None for k in ['proto', 'host', 'port', 'password', 'username', 'vhost']):
+ self.module.fail_json(msg="Connection parameters must be passed via url, or, proto, host, port, vhost, username or password.")
+
+ def append_ssl_certs(self):
+ ssl_options = {}
+ if self.cafile:
+ ssl_options['cafile'] = self.cafile
+ if self.certfile:
+ ssl_options['certfile'] = self.certfile
+ if self.keyfile:
+ ssl_options['keyfile'] = self.keyfile
+
+ self.url = self.url + '?ssl_options=' + urllib_parse.quote(json.dumps(ssl_options))
+
+ @staticmethod
+ def rabbitmq_argument_spec():
+ return dict(
+ url=dict(type='str'),
+ proto=dict(type='str', choices=['amqp', 'amqps']),
+ host=dict(type='str'),
+ port=dict(type='int'),
+ username=dict(type='str'),
+ password=dict(type='str', no_log=True),
+ vhost=dict(type='str'),
+ queue=dict(type='str')
+ )
+
+ # Consider some file size limits here
+ def _read_file(self, path):
+ try:
+ with open(path, "rb") as file_handle:
+ return file_handle.read()
+ except IOError as e:
+ self.module.fail_json(msg="Unable to open file %s: %s" % (path, to_native(e)))
+
+ @staticmethod
+ def _check_file_mime_type(path):
+ mime = MimeTypes()
+ return mime.guess_type(path)
+
+ def build_url(self):
+ self.url = '{0}://{1}:{2}@{3}:{4}/{5}'.format(self.proto,
+ self.username,
+ self.password,
+ self.host,
+ self.port,
+ self.vhost)
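+
+ # With (hypothetical) params proto='amqps', username='guest',
+ # password='guest', host='rmq.example.com', port=5671, vhost='myvhost',
+ # this produces: amqps://guest:guest@rmq.example.com:5671/myvhost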
+
+ def connect_to_rabbitmq(self):
+ """
+ Connect to RabbitMQ using the configured username and password.
+ """
+ try:
+ parameters = pika.URLParameters(self.url)
+ except Exception as e:
+ self.module.fail_json(msg="URL malformed: %s" % to_native(e))
+
+ try:
+ self.connection = pika.BlockingConnection(parameters)
+ except Exception as e:
+ self.module.fail_json(msg="Connection issue: %s" % to_native(e))
+
+ try:
+ self.conn_channel = self.connection.channel()
+ except pika.exceptions.AMQPChannelError as e:
+ self.close_connection()
+ self.module.fail_json(msg="Channel issue: %s" % to_native(e))
+
+ def close_connection(self):
+ try:
+ self.connection.close()
+ except pika.exceptions.AMQPConnectionError:
+ pass
+
+ def basic_publish(self):
+ self.content_type = self.params.get("content_type")
+
+ if self.params.get("body") is not None:
+ args = dict(
+ body=self.params.get("body"),
+ exchange=self.params.get("exchange"),
+ routing_key=self.params.get("routing_key"),
+ properties=pika.BasicProperties(content_type=self.content_type, delivery_mode=1, headers=self.headers))
+
+ # If src (file) is defined and content_type is left as default, do a mime lookup on the file
+ if self.params.get("src") is not None and self.content_type == 'text/plain':
+ self.content_type = RabbitClient._check_file_mime_type(self.params.get("src"))[0]
+ self.headers.update(
+ filename=os.path.basename(self.params.get("src"))
+ )
+
+ args = dict(
+ body=self._read_file(self.params.get("src")),
+ exchange=self.params.get("exchange"),
+ routing_key=self.params.get("routing_key"),
+ properties=pika.BasicProperties(content_type=self.content_type,
+ delivery_mode=1,
+ headers=self.headers
+ ))
+ elif self.params.get("src") is not None:
+ args = dict(
+ body=self._read_file(self.params.get("src")),
+ exchange=self.params.get("exchange"),
+ routing_key=self.params.get("routing_key"),
+ properties=pika.BasicProperties(content_type=self.content_type,
+ delivery_mode=1,
+ headers=self.headers
+ ))
+
+ try:
+ # If queue is not defined, RabbitMQ will return the queue name of the automatically generated queue.
+ if self.queue is None:
+ result = self.conn_channel.queue_declare(durable=self.params.get("durable"),
+ exclusive=self.params.get("exclusive"),
+ auto_delete=self.params.get("auto_delete"))
+ self.conn_channel.confirm_delivery()
+ self.queue = result.method.queue
+ else:
+ self.conn_channel.queue_declare(queue=self.queue,
+ durable=self.params.get("durable"),
+ exclusive=self.params.get("exclusive"),
+ auto_delete=self.params.get("auto_delete"))
+ self.conn_channel.confirm_delivery()
+ except Exception as e:
+ self.module.fail_json(msg="Queue declare issue: %s" % to_native(e))
+
+ # https://github.com/ansible/ansible/blob/devel/lib/ansible/module_utils/cloudstack.py#L150
+ if args['routing_key'] is None:
+ args['routing_key'] = self.queue
+
+ if args['exchange'] is None:
+ args['exchange'] = ''
+
+ try:
+ self.conn_channel.basic_publish(**args)
+ return True
+ except pika.exceptions.UnroutableError:
+ return False
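+
+# Usage sketch (illustrative): a publishing module builds its argument spec
+# from RabbitClient.rabbitmq_argument_spec() plus body/src/exchange/routing_key
+# options, then hands the AnsibleModule instance to RabbitClient, which
+# connects during __init__.
+#
+# rabbitmq = RabbitClient(module)
+# if rabbitmq.basic_publish():
+# rabbitmq.close_connection()
+# module.exit_json(changed=True, queue=rabbitmq.queue)
+# else:
+# rabbitmq.close_connection()
+# module.fail_json(msg="Unable to publish to queue %s" % rabbitmq.queue)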
diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py
new file mode 120000
index 00000000..f9993bfb
--- /dev/null
+++ b/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py
@@ -0,0 +1 @@
+azure_rm_mariadbconfiguration_info.py \ No newline at end of file
diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py
new file mode 120000
index 00000000..b8293e64
--- /dev/null
+++ b/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py
@@ -0,0 +1 @@
+azure_rm_mariadbdatabase_info.py \ No newline at end of file
diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py
new file mode 120000
index 00000000..4311a0c1
--- /dev/null
+++ b/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py
@@ -0,0 +1 @@
+azure_rm_mariadbfirewallrule_info.py \ No newline at end of file
diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py
new file mode 120000
index 00000000..5f76e0e9
--- /dev/null
+++ b/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py
@@ -0,0 +1 @@
+azure_rm_mariadbserver_info.py \ No newline at end of file
diff --git a/test/support/integration/plugins/modules/_azure_rm_resource_facts.py b/test/support/integration/plugins/modules/_azure_rm_resource_facts.py
new file mode 120000
index 00000000..710fda10
--- /dev/null
+++ b/test/support/integration/plugins/modules/_azure_rm_resource_facts.py
@@ -0,0 +1 @@
+azure_rm_resource_info.py \ No newline at end of file
diff --git a/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py b/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py
new file mode 120000
index 00000000..ead87c85
--- /dev/null
+++ b/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py
@@ -0,0 +1 @@
+azure_rm_webapp_info.py \ No newline at end of file
diff --git a/test/support/integration/plugins/modules/aws_az_info.py b/test/support/integration/plugins/modules/aws_az_info.py
new file mode 100644
index 00000000..c1efed6f
--- /dev/null
+++ b/test/support/integration/plugins/modules/aws_az_info.py
@@ -0,0 +1,111 @@
+#!/usr/bin/python
+# Copyright (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+ANSIBLE_METADATA = {
+ 'metadata_version': '1.1',
+ 'supported_by': 'community',
+ 'status': ['preview']
+}
+
+DOCUMENTATION = '''
+module: aws_az_info
+short_description: Gather information about availability zones in AWS.
+description:
+ - Gather information about availability zones in AWS.
+ - This module was called C(aws_az_facts) before Ansible 2.9. The usage did not change.
+version_added: '2.5'
+author: 'Henrique Rodrigues (@Sodki)'
+options:
+ filters:
+ description:
+ - A dict of filters to apply. Each dict item consists of a filter key and a filter value. See
+ U(https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeAvailabilityZones.html) for
+ possible filters. Filter names and values are case sensitive. You can also use underscores
+ instead of dashes (-) in the filter keys, which will take precedence in case of conflict.
+ required: false
+ default: {}
+ type: dict
+extends_documentation_fragment:
+ - aws
+ - ec2
+requirements: [botocore, boto3]
+'''
+
+EXAMPLES = '''
+# Note: These examples do not set authentication details, see the AWS Guide for details.
+
+# Gather information about all availability zones
+- aws_az_info:
+
+# Gather information about a single availability zone
+- aws_az_info:
+ filters:
+ zone-name: eu-west-1a
+'''
+
+RETURN = '''
+availability_zones:
+ returned: on success
+ description: >
+ Availability zones that match the provided filters. Each element consists of a dict with all the information
+ related to that availability zone.
+ type: list
+ sample: "[
+ {
+ 'messages': [],
+ 'region_name': 'us-west-1',
+ 'state': 'available',
+ 'zone_name': 'us-west-1b'
+ },
+ {
+ 'messages': [],
+ 'region_name': 'us-west-1',
+ 'state': 'available',
+ 'zone_name': 'us-west-1c'
+ }
+ ]"
+'''
+
+from ansible.module_utils.aws.core import AnsibleAWSModule
+from ansible.module_utils.ec2 import AWSRetry, ansible_dict_to_boto3_filter_list, camel_dict_to_snake_dict
+
+try:
+ from botocore.exceptions import ClientError, BotoCoreError
+except ImportError:
+ pass # Handled by AnsibleAWSModule
+
+
+def main():
+ argument_spec = dict(
+ filters=dict(default={}, type='dict')
+ )
+
+ module = AnsibleAWSModule(argument_spec=argument_spec)
+ if module._name == 'aws_az_facts':
+ module.deprecate("The 'aws_az_facts' module has been renamed to 'aws_az_info'",
+ version='2.14', collection_name='ansible.builtin')
+
+ connection = module.client('ec2', retry_decorator=AWSRetry.jittered_backoff())
+
+ # Replace filter key underscores with dashes, for compatibility
+ sanitized_filters = dict((k.replace('_', '-'), v) for k, v in module.params.get('filters').items())
+
+ try:
+ availability_zones = connection.describe_availability_zones(
+ Filters=ansible_dict_to_boto3_filter_list(sanitized_filters)
+ )
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Unable to describe availability zones.")
+
+ # Turn the boto3 result into ansible_friendly_snaked_names
+ snaked_availability_zones = [camel_dict_to_snake_dict(az) for az in availability_zones['AvailabilityZones']]
+
+ module.exit_json(availability_zones=snaked_availability_zones)
+
+
+if __name__ == '__main__':
+ main()
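+
+# Illustrative mapping of the filter sanitization and boto3 translation above:
+# {'zone_name': 'us-west-1b'} -> {'zone-name': 'us-west-1b'}
+# ansible_dict_to_boto3_filter_list({'zone-name': 'us-west-1b'})
+# -> [{'Name': 'zone-name', 'Values': ['us-west-1b']}]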
diff --git a/test/support/integration/plugins/modules/aws_s3.py b/test/support/integration/plugins/modules/aws_s3.py
new file mode 100644
index 00000000..54874f05
--- /dev/null
+++ b/test/support/integration/plugins/modules/aws_s3.py
@@ -0,0 +1,925 @@
+#!/usr/bin/python
+# This file is part of Ansible
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+
+DOCUMENTATION = '''
+---
+module: aws_s3
+short_description: manage objects in S3.
+description:
+ - This module allows the user to manage S3 buckets and the objects within them. Includes support for creating and
+ deleting both objects and buckets, retrieving objects as files or strings and generating download links.
+ This module has a dependency on boto3 and botocore.
+notes:
+ - In 2.4, this module was renamed from C(s3) to M(aws_s3).
+version_added: "1.1"
+options:
+ bucket:
+ description:
+ - Bucket name.
+ required: true
+ type: str
+ dest:
+ description:
+ - The destination file path when downloading an object/key with a GET operation.
+ version_added: "1.3"
+ type: path
+ encrypt:
+ description:
+ - When set for PUT mode, asks for server-side encryption.
+ default: true
+ version_added: "2.0"
+ type: bool
+ encryption_mode:
+ description:
+ - What encryption mode to use if I(encrypt=true).
+ default: AES256
+ choices:
+ - AES256
+ - aws:kms
+ version_added: "2.7"
+ type: str
+ expiry:
+ description:
+ - Time limit (in seconds) for the URL generated and returned by S3/Walrus when performing a I(mode=put) or I(mode=geturl) operation.
+ default: 600
+ aliases: ['expiration']
+ type: int
+ headers:
+ description:
+ - Custom headers for PUT operation, as a dictionary of 'key=value' and 'key=value,key=value'.
+ version_added: "2.0"
+ type: dict
+ marker:
+ description:
+ - Specifies the key to start with when using list mode. Object keys are returned in alphabetical order, starting with the key after the marker.
+ version_added: "2.0"
+ type: str
+ max_keys:
+ description:
+ - Max number of results to return in list mode, set this if you want to retrieve fewer than the default 1000 keys.
+ default: 1000
+ version_added: "2.0"
+ type: int
+ metadata:
+ description:
+ - Metadata for PUT operation, as a dictionary of 'key=value' and 'key=value,key=value'.
+ version_added: "1.6"
+ type: dict
+ mode:
+ description:
+ - Switches the module behaviour between put (upload), get (download), geturl (return download url, Ansible 1.3+),
+ getstr (download object as string (1.3+)), list (list keys, Ansible 2.0+), create (bucket), delete (bucket),
+ and delobj (delete object, Ansible 2.0+).
+ required: true
+ choices: ['get', 'put', 'delete', 'create', 'geturl', 'getstr', 'delobj', 'list']
+ type: str
+ object:
+ description:
+ - Keyname of the object inside the bucket. Can be used to create "virtual directories", see examples.
+ type: str
+ permission:
+ description:
+ - This option lets the user set the canned permissions on the object/bucket that are created.
+ The permissions that can be set are C(private), C(public-read), C(public-read-write), C(authenticated-read) for a bucket or
+ C(private), C(public-read), C(public-read-write), C(aws-exec-read), C(authenticated-read), C(bucket-owner-read),
+ C(bucket-owner-full-control) for an object. Multiple permissions can be specified as a list.
+ default: ['private']
+ version_added: "2.0"
+ type: list
+ elements: str
+ prefix:
+ description:
+ - Limits the response to keys that begin with the specified prefix for list mode.
+ default: ""
+ version_added: "2.0"
+ type: str
+ version:
+ description:
+ - Version ID of the object inside the bucket. Can be used to get a specific version of a file if versioning is enabled in the target bucket.
+ version_added: "2.0"
+ type: str
+ overwrite:
+ description:
+ - Force overwrite either locally on the filesystem or remotely with the object/key. Used with PUT and GET operations.
+ Boolean or one of [always, never, different]: true is equivalent to 'always' and false to 'never' (new in 2.0).
+ When this is set to 'different', the md5 sum of the local file is compared with the 'ETag' of the object/key in S3.
+ The ETag may or may not be an MD5 digest of the object data. See the ETag response header here
+ U(https://docs.aws.amazon.com/AmazonS3/latest/API/RESTCommonResponseHeaders.html)
+ default: 'always'
+ aliases: ['force']
+ version_added: "1.2"
+ type: str
+ retries:
+ description:
+ - On recoverable failure, how many times to retry before actually failing.
+ default: 0
+ version_added: "2.0"
+ type: int
+ aliases: ['retry']
+ s3_url:
+ description:
+ - S3 URL endpoint for usage with Ceph, Eucalyptus and fakes3 etc. Otherwise assumes AWS.
+ aliases: [ S3_URL ]
+ type: str
+ dualstack:
+ description:
+ - Enables Amazon S3 Dual-Stack Endpoints, allowing S3 communications using both IPv4 and IPv6.
+ - Requires at least botocore version 1.4.45.
+ type: bool
+ default: false
+ version_added: "2.7"
+ rgw:
+ description:
+ - Enable Ceph RGW S3 support. This option requires an explicit url via I(s3_url).
+ default: false
+ version_added: "2.2"
+ type: bool
+ src:
+ description:
+ - The source file path when performing a PUT operation.
+ version_added: "1.3"
+ type: str
+ ignore_nonexistent_bucket:
+ description:
+ - "Overrides initial bucket lookups in case bucket or iam policies are restrictive. Example: a user may have the
+ GetObject permission but no other permissions. In this case using the option mode: get will fail without specifying
+ I(ignore_nonexistent_bucket=true)."
+ version_added: "2.3"
+ type: bool
+ encryption_kms_key_id:
+ description:
+ - KMS key id to use when encrypting objects using I(encryption_mode=aws:kms). Ignored if I(encryption_mode) is not C(aws:kms).
+ version_added: "2.7"
+ type: str
+requirements: [ "boto3", "botocore" ]
+author:
+ - "Lester Wade (@lwade)"
+ - "Sloane Hertel (@s-hertel)"
+extends_documentation_fragment:
+ - aws
+ - ec2
+'''
+
+EXAMPLES = '''
+- name: Simple PUT operation
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ src: /usr/local/myfile.txt
+ mode: put
+
+- name: Simple PUT operation in Ceph RGW S3
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ src: /usr/local/myfile.txt
+ mode: put
+ rgw: true
+ s3_url: "http://localhost:8000"
+
+- name: Simple GET operation
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ dest: /usr/local/myfile.txt
+ mode: get
+
+- name: Get a specific version of an object.
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ version: 48c9ee5131af7a716edc22df9772aa6f
+ dest: /usr/local/myfile.txt
+ mode: get
+
+- name: PUT/upload with metadata
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ src: /usr/local/myfile.txt
+ mode: put
+ metadata: 'Content-Encoding=gzip,Cache-Control=no-cache'
+
+- name: PUT/upload with custom headers
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ src: /usr/local/myfile.txt
+ mode: put
+ headers: 'x-amz-grant-full-control=emailAddress=owner@example.com'
+
+- name: List keys simple
+ aws_s3:
+ bucket: mybucket
+ mode: list
+
+- name: List keys all options
+ aws_s3:
+ bucket: mybucket
+ mode: list
+ prefix: /my/desired/
+ marker: /my/desired/0023.txt
+ max_keys: 472
+
+- name: Create an empty bucket
+ aws_s3:
+ bucket: mybucket
+ mode: create
+ permission: public-read
+
+- name: Create a bucket with key as directory, in the EU region
+ aws_s3:
+ bucket: mybucket
+ object: /my/directory/path
+ mode: create
+ region: eu-west-1
+
+- name: Delete a bucket and all contents
+ aws_s3:
+ bucket: mybucket
+ mode: delete
+
+- name: GET an object but don't download if the file checksums match. New in 2.0
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ dest: /usr/local/myfile.txt
+ mode: get
+ overwrite: different
+
+- name: Delete an object from a bucket
+ aws_s3:
+ bucket: mybucket
+ object: /my/desired/key.txt
+ mode: delobj
+'''
+
+RETURN = '''
+msg:
+ description: Message indicating the status of the operation.
+ returned: always
+ type: str
+ sample: PUT operation complete
+url:
+ description: URL of the object.
+ returned: (for put and geturl operations)
+ type: str
+ sample: https://my-bucket.s3.amazonaws.com/my-key.txt?AWSAccessKeyId=<access-key>&Expires=1506888865&Signature=<signature>
+expiry:
+ description: Number of seconds the presigned url is valid for.
+ returned: (for geturl operation)
+ type: int
+ sample: 600
+contents:
+ description: Contents of the object as string.
+ returned: (for getstr operation)
+ type: str
+ sample: "Hello, world!"
+s3_keys:
+ description: List of object keys.
+ returned: (for list operation)
+ type: list
+ elements: str
+ sample:
+ - prefix1/
+ - prefix1/key1
+ - prefix1/key2
+'''
+
+import mimetypes
+import os
+from ansible.module_utils.six.moves.urllib.parse import urlparse
+from ssl import SSLError
+from ansible.module_utils.basic import to_text, to_native
+from ansible.module_utils.aws.core import AnsibleAWSModule
+from ansible.module_utils.aws.s3 import calculate_etag, HAS_MD5
+from ansible.module_utils.ec2 import get_aws_connection_info, boto3_conn
+
+try:
+ import botocore
+except ImportError:
+ pass # will be detected by imported AnsibleAWSModule
+
+IGNORE_S3_DROP_IN_EXCEPTIONS = ['XNotImplemented', 'NotImplemented']
+
+
+class Sigv4Required(Exception):
+ pass
+
+
+def key_check(module, s3, bucket, obj, version=None, validate=True):
+ exists = True
+ try:
+ if version:
+ s3.head_object(Bucket=bucket, Key=obj, VersionId=version)
+ else:
+ s3.head_object(Bucket=bucket, Key=obj)
+ except botocore.exceptions.ClientError as e:
+ # if a client error is thrown, check if it's a 404 error
+ # if it's a 404 error, then the object does not exist
+ error_code = int(e.response['Error']['Code'])
+ if error_code == 404:
+ exists = False
+ elif error_code == 403 and validate is False:
+ pass
+ else:
+ module.fail_json_aws(e, msg="Failed while looking up object (during key check) %s." % obj)
+ except botocore.exceptions.BotoCoreError as e:
+ module.fail_json_aws(e, msg="Failed while looking up object (during key check) %s." % obj)
+ return exists
+
+
+def etag_compare(module, local_file, s3, bucket, obj, version=None):
+ s3_etag = get_etag(s3, bucket, obj, version=version)
+ local_etag = calculate_etag(module, local_file, s3_etag, s3, bucket, obj, version)
+
+ return s3_etag == local_etag
+
+
+def get_etag(s3, bucket, obj, version=None):
+ if version:
+ key_check = s3.head_object(Bucket=bucket, Key=obj, VersionId=version)
+ else:
+ key_check = s3.head_object(Bucket=bucket, Key=obj)
+ if not key_check:
+ return None
+ return key_check['ETag']
+
+
+def bucket_check(module, s3, bucket, validate=True):
+ exists = True
+ try:
+ s3.head_bucket(Bucket=bucket)
+ except botocore.exceptions.ClientError as e:
+ # If a client error is thrown, then check that it was a 404 error.
+ # If it was a 404 error, then the bucket does not exist.
+ error_code = int(e.response['Error']['Code'])
+ if error_code == 404:
+ exists = False
+ elif error_code == 403 and validate is False:
+ pass
+ else:
+ module.fail_json_aws(e, msg="Failed while looking up bucket (during bucket_check) %s." % bucket)
+ except botocore.exceptions.EndpointConnectionError as e:
+ module.fail_json_aws(e, msg="Invalid endpoint provided")
+ except botocore.exceptions.BotoCoreError as e:
+ module.fail_json_aws(e, msg="Failed while looking up bucket (during bucket_check) %s." % bucket)
+ return exists
+
+
+def create_bucket(module, s3, bucket, location=None):
+ if module.check_mode:
+ module.exit_json(msg="CREATE operation skipped - running in check mode", changed=True)
+ configuration = {}
+ if location not in ('us-east-1', None):
+ configuration['LocationConstraint'] = location
+ try:
+ if len(configuration) > 0:
+ s3.create_bucket(Bucket=bucket, CreateBucketConfiguration=configuration)
+ else:
+ s3.create_bucket(Bucket=bucket)
+ if module.params.get('permission'):
+ # Wait for the bucket to exist before setting ACLs
+ s3.get_waiter('bucket_exists').wait(Bucket=bucket)
+ for acl in module.params.get('permission'):
+ s3.put_bucket_acl(ACL=acl, Bucket=bucket)
+ except botocore.exceptions.ClientError as e:
+ if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS:
+ module.warn("PutBucketAcl is not implemented by your storage provider. Set the permission parameters to the empty list to avoid this warning")
+ else:
+ module.fail_json_aws(e, msg="Failed while creating bucket or setting acl (check that you have CreateBucket and PutBucketAcl permission).")
+ except botocore.exceptions.BotoCoreError as e:
+ module.fail_json_aws(e, msg="Failed while creating bucket or setting acl (check that you have CreateBucket and PutBucketAcl permission).")
+
+ if bucket:
+ return True
+
+
+def paginated_list(s3, **pagination_params):
+ pg = s3.get_paginator('list_objects_v2')
+ for page in pg.paginate(**pagination_params):
+ yield [data['Key'] for data in page.get('Contents', [])]
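+
+# Usage sketch (illustrative): each yielded page is a list of key names, so a
+# caller can stream keys lazily or, as list_keys() below does, sum() the pages
+# into one flat list.
+#
+# for page in paginated_list(s3, Bucket='mybucket', Prefix='logs/'):
+# for key in page:
+# print(key)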
+
+
+def paginated_versioned_list_with_fallback(s3, **pagination_params):
+ try:
+ versioned_pg = s3.get_paginator('list_object_versions')
+ for page in versioned_pg.paginate(**pagination_params):
+ delete_markers = [{'Key': data['Key'], 'VersionId': data['VersionId']} for data in page.get('DeleteMarkers', [])]
+ current_objects = [{'Key': data['Key'], 'VersionId': data['VersionId']} for data in page.get('Versions', [])]
+ yield delete_markers + current_objects
+ except botocore.exceptions.ClientError as e:
+ if to_text(e.response['Error']['Code']) in IGNORE_S3_DROP_IN_EXCEPTIONS + ['AccessDenied']:
+ for page in paginated_list(s3, **pagination_params):
+ yield [{'Key': data['Key']} for data in page]
+
+
+def list_keys(module, s3, bucket, prefix, marker, max_keys):
+ pagination_params = {'Bucket': bucket}
+ for param_name, param_value in (('Prefix', prefix), ('StartAfter', marker), ('MaxKeys', max_keys)):
+ pagination_params[param_name] = param_value
+ try:
+ keys = sum(paginated_list(s3, **pagination_params), [])
+ module.exit_json(msg="LIST operation complete", s3_keys=keys)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed while listing the keys in the bucket {0}".format(bucket))
+
+
+def delete_bucket(module, s3, bucket):
+ if module.check_mode:
+ module.exit_json(msg="DELETE operation skipped - running in check mode", changed=True)
+ try:
+ exists = bucket_check(module, s3, bucket)
+ if exists is False:
+ return False
+ # if there are contents then we need to delete them before we can delete the bucket
+ for keys in paginated_versioned_list_with_fallback(s3, Bucket=bucket):
+ if keys:
+ s3.delete_objects(Bucket=bucket, Delete={'Objects': keys})
+ s3.delete_bucket(Bucket=bucket)
+ return True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed while deleting bucket %s." % bucket)
+
+
+def delete_key(module, s3, bucket, obj):
+ if module.check_mode:
+ module.exit_json(msg="DELETE operation skipped - running in check mode", changed=True)
+ try:
+ s3.delete_object(Bucket=bucket, Key=obj)
+ module.exit_json(msg="Object deleted from bucket %s." % (bucket), changed=True)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed while trying to delete %s." % obj)
+
+
+def create_dirkey(module, s3, bucket, obj, encrypt):
+ if module.check_mode:
+ module.exit_json(msg="PUT operation skipped - running in check mode", changed=True)
+ try:
+ params = {'Bucket': bucket, 'Key': obj, 'Body': b''}
+ if encrypt:
+ params['ServerSideEncryption'] = module.params['encryption_mode']
+ if module.params['encryption_kms_key_id'] and module.params['encryption_mode'] == 'aws:kms':
+ params['SSEKMSKeyId'] = module.params['encryption_kms_key_id']
+
+ s3.put_object(**params)
+ for acl in module.params.get('permission'):
+ s3.put_object_acl(ACL=acl, Bucket=bucket, Key=obj)
+ except botocore.exceptions.ClientError as e:
+ if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS:
+ module.warn("PutObjectAcl is not implemented by your storage provider. Set the permissions parameters to the empty list to avoid this warning")
+ else:
+ module.fail_json_aws(e, msg="Failed while creating object %s." % obj)
+ except botocore.exceptions.BotoCoreError as e:
+ module.fail_json_aws(e, msg="Failed while creating object %s." % obj)
+ module.exit_json(msg="Virtual directory %s created in bucket %s" % (obj, bucket), changed=True)
+
+
+def path_check(path):
+ if os.path.exists(path):
+ return True
+ else:
+ return False
+
+
+def option_in_extra_args(option):
+ temp_option = option.replace('-', '').lower()
+
+ allowed_extra_args = {'acl': 'ACL', 'cachecontrol': 'CacheControl', 'contentdisposition': 'ContentDisposition',
+ 'contentencoding': 'ContentEncoding', 'contentlanguage': 'ContentLanguage',
+ 'contenttype': 'ContentType', 'expires': 'Expires', 'grantfullcontrol': 'GrantFullControl',
+ 'grantread': 'GrantRead', 'grantreadacp': 'GrantReadACP', 'grantwriteacp': 'GrantWriteACP',
+ 'metadata': 'Metadata', 'requestpayer': 'RequestPayer', 'serversideencryption': 'ServerSideEncryption',
+ 'storageclass': 'StorageClass', 'ssecustomeralgorithm': 'SSECustomerAlgorithm', 'ssecustomerkey': 'SSECustomerKey',
+ 'ssecustomerkeymd5': 'SSECustomerKeyMD5', 'ssekmskeyid': 'SSEKMSKeyId', 'websiteredirectlocation': 'WebsiteRedirectLocation'}
+
+ if temp_option in allowed_extra_args:
+ return allowed_extra_args[temp_option]
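+
+# Examples (illustrative): header-style names are normalised before lookup, so
+# 'Content-Type', 'content-type' and 'ContentType' all map to the boto3
+# ExtraArgs key 'ContentType'; unknown options fall through and are stored as
+# object metadata by upload_s3file() below.
+#
+# option_in_extra_args('Content-Type') # -> 'ContentType'
+# option_in_extra_args('x-custom-tag') # -> None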
+
+
+def upload_s3file(module, s3, bucket, obj, src, expiry, metadata, encrypt, headers):
+ if module.check_mode:
+ module.exit_json(msg="PUT operation skipped - running in check mode", changed=True)
+ try:
+ extra = {}
+ if encrypt:
+ extra['ServerSideEncryption'] = module.params['encryption_mode']
+ if module.params['encryption_kms_key_id'] and module.params['encryption_mode'] == 'aws:kms':
+ extra['SSEKMSKeyId'] = module.params['encryption_kms_key_id']
+ if metadata:
+ extra['Metadata'] = {}
+
+ # determine object metadata and extra arguments
+ for option in metadata:
+ extra_args_option = option_in_extra_args(option)
+ if extra_args_option is not None:
+ extra[extra_args_option] = metadata[option]
+ else:
+ extra['Metadata'][option] = metadata[option]
+
+ if 'ContentType' not in extra:
+ content_type = mimetypes.guess_type(src)[0]
+ if content_type is None:
+ # s3 default content type
+ content_type = 'binary/octet-stream'
+ extra['ContentType'] = content_type
+
+ s3.upload_file(Filename=src, Bucket=bucket, Key=obj, ExtraArgs=extra)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Unable to complete PUT operation.")
+ try:
+ for acl in module.params.get('permission'):
+ s3.put_object_acl(ACL=acl, Bucket=bucket, Key=obj)
+ except botocore.exceptions.ClientError as e:
+ if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS:
+ module.warn("PutObjectAcl is not implemented by your storage provider. Set the permission parameters to the empty list to avoid this warning")
+ else:
+ module.fail_json_aws(e, msg="Unable to set object ACL")
+ except botocore.exceptions.BotoCoreError as e:
+ module.fail_json_aws(e, msg="Unable to set object ACL")
+ try:
+ url = s3.generate_presigned_url(ClientMethod='put_object',
+ Params={'Bucket': bucket, 'Key': obj},
+ ExpiresIn=expiry)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Unable to generate presigned URL")
+ module.exit_json(msg="PUT operation complete", url=url, changed=True)
+
+
+def download_s3file(module, s3, bucket, obj, dest, retries, version=None):
+ if module.check_mode:
+ module.exit_json(msg="GET operation skipped - running in check mode", changed=True)
+ # retries is the number of loops; range/xrange needs to be one
+ # more to get that count of loops.
+ try:
+ if version:
+ key = s3.get_object(Bucket=bucket, Key=obj, VersionId=version)
+ else:
+ key = s3.get_object(Bucket=bucket, Key=obj)
+ except botocore.exceptions.ClientError as e:
+ if e.response['Error']['Code'] == 'InvalidArgument' and 'require AWS Signature Version 4' in to_text(e):
+ raise Sigv4Required()
+ elif e.response['Error']['Code'] not in ("403", "404"):
+ # AccessDenied errors may be triggered if 1) file does not exist or 2) file exists but
+ # user does not have the s3:GetObject permission. 404 errors are handled by download_file().
+ module.fail_json_aws(e, msg="Could not find the key %s." % obj)
+ except botocore.exceptions.BotoCoreError as e:
+ module.fail_json_aws(e, msg="Could not find the key %s." % obj)
+
+ optional_kwargs = {'ExtraArgs': {'VersionId': version}} if version else {}
+ for x in range(0, retries + 1):
+ try:
+ s3.download_file(bucket, obj, dest, **optional_kwargs)
+ module.exit_json(msg="GET operation complete", changed=True)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ # actually fail on last pass through the loop.
+ if x >= retries:
+ module.fail_json_aws(e, msg="Failed while downloading %s." % obj)
+ # otherwise, try again, this may be a transient timeout.
+ except SSLError as e: # will ClientError catch SSLError?
+ # actually fail on last pass through the loop.
+ if x >= retries:
+ module.fail_json_aws(e, msg="s3 download failed")
+ # otherwise, try again, this may be a transient timeout.
+
+
+def download_s3str(module, s3, bucket, obj, version=None, validate=True):
+ if module.check_mode:
+ module.exit_json(msg="GET operation skipped - running in check mode", changed=True)
+ try:
+ if version:
+ contents = to_native(s3.get_object(Bucket=bucket, Key=obj, VersionId=version)["Body"].read())
+ else:
+ contents = to_native(s3.get_object(Bucket=bucket, Key=obj)["Body"].read())
+ module.exit_json(msg="GET operation complete", contents=contents, changed=True)
+ except botocore.exceptions.ClientError as e:
+ if e.response['Error']['Code'] == 'InvalidArgument' and 'require AWS Signature Version 4' in to_text(e):
+ raise Sigv4Required()
+ else:
+ module.fail_json_aws(e, msg="Failed while getting contents of object %s as a string." % obj)
+ except botocore.exceptions.BotoCoreError as e:
+ module.fail_json_aws(e, msg="Failed while getting contents of object %s as a string." % obj)
+
+
+def get_download_url(module, s3, bucket, obj, expiry, changed=True):
+ try:
+ url = s3.generate_presigned_url(ClientMethod='get_object',
+ Params={'Bucket': bucket, 'Key': obj},
+ ExpiresIn=expiry)
+ module.exit_json(msg="Download url:", url=url, expiry=expiry, changed=changed)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed while getting download url.")
+
+
+def is_fakes3(s3_url):
+ """ Return True if s3_url has scheme fakes3:// """
+ if s3_url is not None:
+ return urlparse(s3_url).scheme in ('fakes3', 'fakes3s')
+ else:
+ return False
+
+
+def get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=False):
+ if s3_url and rgw: # TODO - test this
+ rgw = urlparse(s3_url)
+ params = dict(module=module, conn_type='client', resource='s3', use_ssl=rgw.scheme == 'https', region=location, endpoint=s3_url, **aws_connect_kwargs)
+ elif is_fakes3(s3_url):
+ fakes3 = urlparse(s3_url)
+ port = fakes3.port
+ if fakes3.scheme == 'fakes3s':
+ protocol = "https"
+ if port is None:
+ port = 443
+ else:
+ protocol = "http"
+ if port is None:
+ port = 80
+ params = dict(module=module, conn_type='client', resource='s3', region=location,
+ endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)),
+ use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs)
+ else:
+ params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=s3_url, **aws_connect_kwargs)
+ if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms':
+ params['config'] = botocore.client.Config(signature_version='s3v4')
+ elif module.params['mode'] in ('get', 'getstr') and sig_4:
+ params['config'] = botocore.client.Config(signature_version='s3v4')
+ if module.params['dualstack']:
+ dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True})
+ if 'config' in params:
+ params['config'] = params['config'].merge(dualconf)
+ else:
+ params['config'] = dualconf
+ return boto3_conn(**params)
+
+
+def main():
+ argument_spec = dict(
+ bucket=dict(required=True),
+ dest=dict(default=None, type='path'),
+ encrypt=dict(default=True, type='bool'),
+ encryption_mode=dict(choices=['AES256', 'aws:kms'], default='AES256'),
+ expiry=dict(default=600, type='int', aliases=['expiration']),
+ headers=dict(type='dict'),
+ marker=dict(default=""),
+ max_keys=dict(default=1000, type='int'),
+ metadata=dict(type='dict'),
+ mode=dict(choices=['get', 'put', 'delete', 'create', 'geturl', 'getstr', 'delobj', 'list'], required=True),
+ object=dict(),
+ permission=dict(type='list', default=['private']),
+ version=dict(default=None),
+ overwrite=dict(aliases=['force'], default='always'),
+ prefix=dict(default=""),
+ retries=dict(aliases=['retry'], type='int', default=0),
+ s3_url=dict(aliases=['S3_URL']),
+ dualstack=dict(default='no', type='bool'),
+ rgw=dict(default='no', type='bool'),
+ src=dict(),
+ ignore_nonexistent_bucket=dict(default=False, type='bool'),
+ encryption_kms_key_id=dict()
+ )
+ module = AnsibleAWSModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True,
+ required_if=[['mode', 'put', ['src', 'object']],
+ ['mode', 'get', ['dest', 'object']],
+ ['mode', 'getstr', ['object']],
+ ['mode', 'geturl', ['object']]],
+ )
+
+ bucket = module.params.get('bucket')
+ encrypt = module.params.get('encrypt')
+ expiry = module.params.get('expiry')
+ dest = module.params.get('dest', '')
+ headers = module.params.get('headers')
+ marker = module.params.get('marker')
+ max_keys = module.params.get('max_keys')
+ metadata = module.params.get('metadata')
+ mode = module.params.get('mode')
+ obj = module.params.get('object')
+ version = module.params.get('version')
+ overwrite = module.params.get('overwrite')
+ prefix = module.params.get('prefix')
+ retries = module.params.get('retries')
+ s3_url = module.params.get('s3_url')
+ dualstack = module.params.get('dualstack')
+ rgw = module.params.get('rgw')
+ src = module.params.get('src')
+ ignore_nonexistent_bucket = module.params.get('ignore_nonexistent_bucket')
+
+ object_canned_acl = ["private", "public-read", "public-read-write", "aws-exec-read", "authenticated-read", "bucket-owner-read", "bucket-owner-full-control"]
+ bucket_canned_acl = ["private", "public-read", "public-read-write", "authenticated-read"]
+
+ if overwrite not in ['always', 'never', 'different']:
+ if module.boolean(overwrite):
+ overwrite = 'always'
+ else:
+ overwrite = 'never'
+
+ if overwrite == 'different' and not HAS_MD5:
+ module.fail_json(msg='overwrite=different is unavailable: ETag calculation requires MD5 support')
+
+ region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)
+
+ if region in ('us-east-1', '', None):
+ # default to US Standard region
+ location = 'us-east-1'
+ else:
+ # Boto uses symbolic names for locations but region strings will
+ # actually work fine for everything except us-east-1 (US Standard)
+ location = region
+
+ if module.params.get('object'):
+ obj = module.params['object']
+ # If the object name starts with /, remove the leading character
+ # to maintain compatibility with Ansible versions < 2.4
+ if obj.startswith('/'):
+ obj = obj[1:]
+
+ # Bucket deletion does not require obj. Prevents ambiguity with delobj.
+ if obj and mode == "delete":
+ module.fail_json(msg='Parameter obj cannot be used with mode=delete')
+
+ # allow eucarc environment variables to be used if ansible vars aren't set
+ if not s3_url and 'S3_URL' in os.environ:
+ s3_url = os.environ['S3_URL']
+
+ if dualstack and s3_url is not None and 'amazonaws.com' not in s3_url:
+ module.fail_json(msg='dualstack only applies to AWS S3')
+
+ if dualstack and not module.botocore_at_least('1.4.45'):
+ module.fail_json(msg='dualstack requires botocore >= 1.4.45')
+
+ # rgw requires an explicit url
+ if rgw and not s3_url:
+ module.fail_json(msg='rgw flavour requires s3_url')
+
+ # Look at s3_url and tweak connection settings
+ # if connecting to RGW, Walrus or fakes3
+ if s3_url:
+ for key in ['validate_certs', 'security_token', 'profile_name']:
+ aws_connect_kwargs.pop(key, None)
+ s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url)
+
+ validate = not ignore_nonexistent_bucket
+
+ # separate types of ACLs
+ bucket_acl = [acl for acl in module.params.get('permission') if acl in bucket_canned_acl]
+ object_acl = [acl for acl in module.params.get('permission') if acl in object_canned_acl]
+ error_acl = [acl for acl in module.params.get('permission') if acl not in bucket_canned_acl and acl not in object_canned_acl]
+ if error_acl:
+ module.fail_json(msg='Unknown permission specified: %s' % error_acl)
+
+ # First, we check to see if the bucket exists, we get "bucket" returned.
+ bucketrtn = bucket_check(module, s3, bucket, validate=validate)
+
+ if validate and mode not in ('create', 'put', 'delete') and not bucketrtn:
+ module.fail_json(msg="Source bucket cannot be found.")
+
+ if mode == 'get':
+ keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate)
+ if keyrtn is False:
+ if version:
+ module.fail_json(msg="Key %s with version id %s does not exist." % (obj, version))
+ else:
+ module.fail_json(msg="Key %s does not exist." % obj)
+
+ if path_check(dest) and overwrite != 'always':
+ if overwrite == 'never':
+ module.exit_json(msg="Local object already exists and overwrite is disabled.", changed=False)
+ if etag_compare(module, dest, s3, bucket, obj, version=version):
+ module.exit_json(msg="Local and remote object are identical, ignoring. Use overwrite=always parameter to force.", changed=False)
+
+ try:
+ download_s3file(module, s3, bucket, obj, dest, retries, version=version)
+ except Sigv4Required:
+ s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=True)
+ download_s3file(module, s3, bucket, obj, dest, retries, version=version)
+
+ if mode == 'put':
+
+ # if putting an object in a bucket yet to be created, acls for the bucket and/or the object may be specified
+ # these were separated into the variables bucket_acl and object_acl above
+
+ if not path_check(src):
+ module.fail_json(msg="Local object for PUT does not exist")
+
+ if bucketrtn:
+ keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate)
+ else:
+ # If the bucket doesn't exist we should create it.
+ # only use valid bucket acls for create_bucket function
+ module.params['permission'] = bucket_acl
+ create_bucket(module, s3, bucket, location)
+ # A freshly created bucket has no existing key to compare against
+ keyrtn = False
+
+ if keyrtn and overwrite != 'always':
+ if overwrite == 'never' or etag_compare(module, src, s3, bucket, obj):
+ # Return the download URL for the existing object
+ get_download_url(module, s3, bucket, obj, expiry, changed=False)
+
+ # only use valid object acls for the upload_s3file function
+ module.params['permission'] = object_acl
+ upload_s3file(module, s3, bucket, obj, src, expiry, metadata, encrypt, headers)
+
+ # Delete an object from a bucket, not the entire bucket
+ if mode == 'delobj':
+ if obj is None:
+ module.fail_json(msg="object parameter is required")
+ if bucket:
+ deletertn = delete_key(module, s3, bucket, obj)
+ if deletertn is True:
+ module.exit_json(msg="Object deleted from bucket %s." % bucket, changed=True)
+ else:
+ module.fail_json(msg="Bucket parameter is required.")
+
+ # Delete an entire bucket, including all objects in the bucket
+ if mode == 'delete':
+ if bucket:
+ deletertn = delete_bucket(module, s3, bucket)
+ if deletertn is True:
+ module.exit_json(msg="Bucket %s and all keys have been deleted." % bucket, changed=True)
+ else:
+ module.fail_json(msg="Bucket parameter is required.")
+
+ # Support for listing a set of keys
+ if mode == 'list':
+ exists = bucket_check(module, s3, bucket)
+
+ # If the bucket does not exist then bail out
+ if not exists:
+ module.fail_json(msg="Target bucket (%s) cannot be found" % bucket)
+
+ list_keys(module, s3, bucket, prefix, marker, max_keys)
+
+ # Need to research how to create directories without "populating" a key, so this should just do bucket creation for now.
+ # WE SHOULD ENABLE SOME WAY OF CREATING AN EMPTY KEY TO CREATE "DIRECTORY" STRUCTURE, AWS CONSOLE DOES THIS.
+ if mode == 'create':
+
+ # if both creating a bucket and putting an object in it, acls for the bucket and/or the object may be specified
+ # these were separated above into the variables bucket_acl and object_acl
+
+ if bucket and not obj:
+ if bucketrtn:
+ module.exit_json(msg="Bucket already exists.", changed=False)
+ else:
+ # only use valid bucket acls when creating the bucket
+ module.params['permission'] = bucket_acl
+ module.exit_json(msg="Bucket created successfully", changed=create_bucket(module, s3, bucket, location))
+ if bucket and obj:
+ if obj.endswith('/'):
+ dirobj = obj
+ else:
+ dirobj = obj + "/"
+ if bucketrtn:
+ if key_check(module, s3, bucket, dirobj):
+ module.exit_json(msg="Bucket %s and key %s already exists." % (bucket, obj), changed=False)
+ else:
+ # setting valid object acls for the create_dirkey function
+ module.params['permission'] = object_acl
+ create_dirkey(module, s3, bucket, dirobj, encrypt)
+ else:
+ # only use valid bucket acls for the create_bucket function
+ module.params['permission'] = bucket_acl
+ created = create_bucket(module, s3, bucket, location)
+ # only use valid object acls for the create_dirkey function
+ module.params['permission'] = object_acl
+ create_dirkey(module, s3, bucket, dirobj, encrypt)
+
+ # Support for grabbing the time-expired URL for an object in S3/Walrus.
+ if mode == 'geturl':
+ if not bucket and not obj:
+ module.fail_json(msg="Bucket and Object parameters must be set")
+
+ keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate)
+ if keyrtn:
+ get_download_url(module, s3, bucket, obj, expiry)
+ else:
+ module.fail_json(msg="Key %s does not exist." % obj)
+
+ if mode == 'getstr':
+ if bucket and obj:
+ keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate)
+ if keyrtn:
+ try:
+ download_s3str(module, s3, bucket, obj, version=version)
+ except Sigv4Required:
+ s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=True)
+ download_s3str(module, s3, bucket, obj, version=version)
+ elif version is not None:
+ module.fail_json(msg="Key %s with version id %s does not exist." % (obj, version))
+ else:
+ module.fail_json(msg="Key %s does not exist." % obj)
+
+ module.exit_json(failed=False)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_appserviceplan.py b/test/support/integration/plugins/modules/azure_rm_appserviceplan.py
new file mode 100644
index 00000000..ee871c35
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_appserviceplan.py
@@ -0,0 +1,379 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_appserviceplan
+version_added: "2.7"
+short_description: Manage App Service Plan
+description:
+ - Create, update and delete instance of App Service Plan.
+
+options:
+ resource_group:
+ description:
+ - Name of the resource group to which the resource belongs.
+ required: True
+
+ name:
+ description:
+ - Unique name of the app service plan to create or update.
+ required: True
+
+ location:
+ description:
+ - Resource location. If not set, location from the resource group will be used as default.
+
+ sku:
+ description:
+            - The pricing tier, e.g. C(F1), C(D1), C(B1), C(B2), C(B3), C(S1), C(P1), C(P1V2).
+            - See U(https://azure.microsoft.com/en-us/pricing/details/app-service/plans/) for more detail.
+            - For Linux app service plans, see U(https://azure.microsoft.com/en-us/pricing/details/app-service/linux/) for more detail.
+ is_linux:
+ description:
+            - Whether to host the web app on a Linux worker.
+ type: bool
+ default: false
+
+ number_of_workers:
+ description:
+            - Number of workers to allocate.
+
+ state:
+ description:
+ - Assert the state of the app service plan.
+ - Use C(present) to create or update an app service plan and C(absent) to delete it.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+ - azure_tags
+
+author:
+ - Yunge Zhu (@yungezz)
+
+'''
+
+EXAMPLES = '''
+ - name: Create a windows app service plan
+ azure_rm_appserviceplan:
+ resource_group: myResourceGroup
+ name: myAppPlan
+ location: eastus
+ sku: S1
+
+ - name: Create a linux app service plan
+ azure_rm_appserviceplan:
+ resource_group: myResourceGroup
+ name: myAppPlan
+ location: eastus
+ sku: S1
+ is_linux: true
+ number_of_workers: 1
+
+  - name: Update sku of an existing windows app service plan
+ azure_rm_appserviceplan:
+ resource_group: myResourceGroup
+ name: myAppPlan
+ location: eastus
+ sku: S2
+'''
+
+RETURN = '''
+azure_appserviceplan:
+ description: Facts about the current state of the app service plan.
+ returned: always
+ type: dict
+ sample: {
+ "id": "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/myAppPlan"
+ }
+'''
+
+import time
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+ from msrestazure.azure_operation import AzureOperationPoller
+ from msrest.serialization import Model
+ from azure.mgmt.web.models import (
+ app_service_plan, AppServicePlan, SkuDescription
+ )
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+def _normalize_sku(sku):
+ if sku is None:
+ return sku
+
+ sku = sku.upper()
+ if sku == 'FREE':
+ return 'F1'
+ elif sku == 'SHARED':
+ return 'D1'
+ return sku
+
+
+def get_sku_name(tier):
+ tier = tier.upper()
+ if tier == 'F1' or tier == "FREE":
+ return 'FREE'
+ elif tier == 'D1' or tier == "SHARED":
+ return 'SHARED'
+ elif tier in ['B1', 'B2', 'B3', 'BASIC']:
+ return 'BASIC'
+ elif tier in ['S1', 'S2', 'S3']:
+ return 'STANDARD'
+ elif tier in ['P1', 'P2', 'P3']:
+ return 'PREMIUM'
+ elif tier in ['P1V2', 'P2V2', 'P3V2']:
+ return 'PREMIUMV2'
+ else:
+ return None
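+# Taken together, the two helpers above map user input onto Azure's sku model: for
+# example 'shared' normalizes to the sku name 'D1', whose tier get_sku_name() reports
+# as 'SHARED'.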
+
+
+def appserviceplan_to_dict(plan):
+ return dict(
+ id=plan.id,
+ name=plan.name,
+ kind=plan.kind,
+ location=plan.location,
+ reserved=plan.reserved,
+ is_linux=plan.reserved,
+ provisioning_state=plan.provisioning_state,
+ status=plan.status,
+ target_worker_count=plan.target_worker_count,
+ sku=dict(
+ name=plan.sku.name,
+ size=plan.sku.size,
+ tier=plan.sku.tier,
+ family=plan.sku.family,
+ capacity=plan.sku.capacity
+ ),
+ resource_group=plan.resource_group,
+ number_of_sites=plan.number_of_sites,
+ tags=plan.tags if plan.tags else None
+ )
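+# Note: the Azure model flags a Linux plan via 'reserved', which is why is_linux above
+# simply mirrors plan.reserved.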
+
+
+class AzureRMAppServicePlans(AzureRMModuleBase):
+ """Configuration class for an Azure RM App Service Plan resource"""
+
+ def __init__(self):
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str',
+ required=True
+ ),
+ location=dict(
+ type='str'
+ ),
+ sku=dict(
+ type='str'
+ ),
+ is_linux=dict(
+ type='bool',
+ default=False
+ ),
+ number_of_workers=dict(
+ type='str'
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+
+ self.resource_group = None
+ self.name = None
+ self.location = None
+
+ self.sku = None
+ self.is_linux = None
+ self.number_of_workers = 1
+
+ self.tags = None
+
+ self.results = dict(
+ changed=False,
+ ansible_facts=dict(azure_appserviceplan=None)
+ )
+ self.state = None
+
+ super(AzureRMAppServicePlans, self).__init__(derived_arg_spec=self.module_arg_spec,
+ supports_check_mode=True,
+ supports_tags=True)
+
+ def exec_module(self, **kwargs):
+ """Main module execution method"""
+
+ for key in list(self.module_arg_spec.keys()) + ['tags']:
+ if kwargs[key]:
+ setattr(self, key, kwargs[key])
+
+ old_response = None
+ response = None
+ to_be_updated = False
+
+ # set location
+ resource_group = self.get_resource_group(self.resource_group)
+ if not self.location:
+ self.location = resource_group.location
+
+ # get app service plan
+ old_response = self.get_plan()
+
+ # if not existing
+ if not old_response:
+ self.log("App Service plan doesn't exist")
+
+ if self.state == "present":
+ to_be_updated = True
+
+ if not self.sku:
+                    self.fail('Please specify sku when creating a new app service plan')
+
+ else:
+ # existing app service plan, do update
+ self.log("App Service Plan already exists")
+
+ if self.state == 'present':
+ self.log('Result: {0}'.format(old_response))
+
+ update_tags, newtags = self.update_tags(old_response.get('tags', dict()))
+
+ if update_tags:
+ to_be_updated = True
+ self.tags = newtags
+
+ # check if sku changed
+ if self.sku and _normalize_sku(self.sku) != old_response['sku']['size']:
+ to_be_updated = True
+
+ # check if number_of_workers changed
+ if self.number_of_workers and int(self.number_of_workers) != old_response['sku']['capacity']:
+ to_be_updated = True
+
+ if self.is_linux and self.is_linux != old_response['reserved']:
+ self.fail("Operation not allowed: cannot update reserved of app service plan.")
+
+ if old_response:
+ self.results['id'] = old_response['id']
+
+ if to_be_updated:
+ self.log('Need to Create/Update app service plan')
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ response = self.create_or_update_plan()
+ self.results['id'] = response['id']
+
+ if self.state == 'absent' and old_response:
+ self.log("Delete app service plan")
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ self.delete_plan()
+
+ self.log('App service plan instance deleted')
+
+ return self.results
+
+ def get_plan(self):
+ '''
+ Gets app service plan
+ :return: deserialized app service plan dictionary
+ '''
+ self.log("Get App Service Plan {0}".format(self.name))
+
+ try:
+ response = self.web_client.app_service_plans.get(self.resource_group, self.name)
+ if response:
+ self.log("Response : {0}".format(response))
+ self.log("App Service Plan : {0} found".format(response.name))
+
+ return appserviceplan_to_dict(response)
+ except CloudError as ex:
+ self.log("Didn't find app service plan {0} in resource group {1}".format(self.name, self.resource_group))
+
+ return False
+
+ def create_or_update_plan(self):
+ '''
+ Creates app service plan
+ :return: deserialized app service plan dictionary
+ '''
+ self.log("Create App Service Plan {0}".format(self.name))
+
+ try:
+ # normalize sku
+ sku = _normalize_sku(self.sku)
+
+            sku_def = SkuDescription(tier=get_sku_name(sku), name=sku, capacity=self.number_of_workers)
+            plan_def = AppServicePlan(location=self.location,
+                                      app_service_plan_name=self.name,
+                                      sku=sku_def,
+                                      reserved=self.is_linux,
+                                      tags=self.tags if self.tags else None)
+
+ response = self.web_client.app_service_plans.create_or_update(self.resource_group, self.name, plan_def)
+
+ if isinstance(response, LROPoller) or isinstance(response, AzureOperationPoller):
+ response = self.get_poller_result(response)
+
+ self.log("Response : {0}".format(response))
+
+ return appserviceplan_to_dict(response)
+ except CloudError as ex:
+ self.fail("Failed to create app service plan {0} in resource group {1}: {2}".format(self.name, self.resource_group, str(ex)))
+
+ def delete_plan(self):
+ '''
+ Deletes specified App service plan in the specified subscription and resource group.
+
+ :return: True
+ '''
+ self.log("Deleting the App service plan {0}".format(self.name))
+ try:
+ response = self.web_client.app_service_plans.delete(resource_group_name=self.resource_group,
+ name=self.name)
+ except CloudError as e:
+ self.log('Error attempting to delete App service plan.')
+ self.fail(
+ "Error deleting the App service plan : {0}".format(str(e)))
+
+ return True
+
+
+def main():
+ """Main execution"""
+ AzureRMAppServicePlans()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_functionapp.py b/test/support/integration/plugins/modules/azure_rm_functionapp.py
new file mode 100644
index 00000000..0c372a88
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_functionapp.py
@@ -0,0 +1,421 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2016, Thomas Stringer <tomstr@microsoft.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_functionapp
+version_added: "2.4"
+short_description: Manage Azure Function Apps
+description:
+ - Create, update or delete an Azure Function App.
+options:
+ resource_group:
+ description:
+ - Name of resource group.
+ required: true
+ aliases:
+ - resource_group_name
+ name:
+ description:
+ - Name of the Azure Function App.
+ required: true
+ location:
+ description:
+ - Valid Azure location. Defaults to location of the resource group.
+ plan:
+ description:
+ - App service plan.
+            - It can be the name of an existing app service plan in the same resource group as the function app.
+            - It can be the resource id of an existing app service plan.
+              For example /subscriptions/<subs_id>/resourceGroups/<resource_group>/providers/Microsoft.Web/serverFarms/<plan_name>.
+            - It can be a dict containing C(name) and C(resource_group).
+            - C(name). Name of the app service plan.
+            - C(resource_group). Resource group name of the app service plan.
+ version_added: "2.8"
+ container_settings:
+ description: Web app container settings.
+ suboptions:
+ name:
+ description:
+ - Name of container. For example "imagename:tag".
+ registry_server_url:
+ description:
+ - Container registry server url. For example C(mydockerregistry.io).
+ registry_server_user:
+ description:
+ - The container registry server user name.
+ registry_server_password:
+ description:
+ - The container registry server password.
+ version_added: "2.8"
+ storage_account:
+ description:
+ - Name of the storage account to use.
+ required: true
+ aliases:
+ - storage
+ - storage_account_name
+ app_settings:
+ description:
+ - Dictionary containing application settings.
+ state:
+ description:
+ - Assert the state of the Function App. Use C(present) to create or update a Function App and C(absent) to delete.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+ - azure_tags
+
+author:
+ - Thomas Stringer (@trstringer)
+'''
+
+EXAMPLES = '''
+- name: Create a function app
+ azure_rm_functionapp:
+ resource_group: myResourceGroup
+ name: myFunctionApp
+ storage_account: myStorageAccount
+
+- name: Create a function app with app settings
+ azure_rm_functionapp:
+ resource_group: myResourceGroup
+ name: myFunctionApp
+ storage_account: myStorageAccount
+ app_settings:
+ setting1: value1
+ setting2: value2
+
+- name: Create container based function app
+ azure_rm_functionapp:
+ resource_group: myResourceGroup
+ name: myFunctionApp
+ storage_account: myStorageAccount
+ plan:
+ resource_group: myResourceGroup
+ name: myAppPlan
+ container_settings:
+ name: httpd
+ registry_server_url: index.docker.io
+
+- name: Delete a function app
+ azure_rm_functionapp:
+ resource_group: myResourceGroup
+ name: myFunctionApp
+ state: absent
+'''
+
+RETURN = '''
+state:
+ description:
+ - Current state of the Azure Function App.
+ returned: success
+ type: dict
+ example:
+ id: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myFunctionApp
+ name: myfunctionapp
+ kind: functionapp
+ location: East US
+ type: Microsoft.Web/sites
+ state: Running
+ host_names:
+ - myfunctionapp.azurewebsites.net
+ repository_site_name: myfunctionapp
+ usage_state: Normal
+ enabled: true
+ enabled_host_names:
+ - myfunctionapp.azurewebsites.net
+ - myfunctionapp.scm.azurewebsites.net
+ availability_state: Normal
+ host_name_ssl_states:
+ - name: myfunctionapp.azurewebsites.net
+ ssl_state: Disabled
+ host_type: Standard
+ - name: myfunctionapp.scm.azurewebsites.net
+ ssl_state: Disabled
+ host_type: Repository
+ server_farm_id: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/EastUSPlan
+ reserved: false
+ last_modified_time_utc: 2017-08-22T18:54:01.190Z
+ scm_site_also_stopped: false
+ client_affinity_enabled: true
+ client_cert_enabled: false
+ host_names_disabled: false
+ outbound_ip_addresses: ............
+ container_size: 1536
+ daily_memory_time_quota: 0
+ resource_group: myResourceGroup
+ default_host_name: myfunctionapp.azurewebsites.net
+''' # NOQA
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from azure.mgmt.web.models import (
+ site_config, app_service_plan, Site, SiteConfig, NameValuePair, SiteSourceControl,
+ AppServicePlan, SkuDescription
+ )
+ from azure.mgmt.resource.resources import ResourceManagementClient
+ from msrest.polling import LROPoller
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+container_settings_spec = dict(
+ name=dict(type='str', required=True),
+ registry_server_url=dict(type='str'),
+ registry_server_user=dict(type='str'),
+ registry_server_password=dict(type='str', no_log=True)
+)
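+# These settings are folded into the site's linux_fx_version in exec_module() below;
+# e.g. (hypothetical values) a name of 'httpd:latest' with registry_server_url
+# 'myregistry.io' yields 'DOCKER|myregistry.io/httpd:latest'.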
+
+
+class AzureRMFunctionApp(AzureRMModuleBase):
+
+ def __init__(self):
+
+ self.module_arg_spec = dict(
+ resource_group=dict(type='str', required=True, aliases=['resource_group_name']),
+ name=dict(type='str', required=True),
+ state=dict(type='str', default='present', choices=['present', 'absent']),
+ location=dict(type='str'),
+ storage_account=dict(
+ type='str',
+ aliases=['storage', 'storage_account_name']
+ ),
+ app_settings=dict(type='dict'),
+ plan=dict(
+ type='raw'
+ ),
+ container_settings=dict(
+ type='dict',
+ options=container_settings_spec
+ )
+ )
+
+ self.results = dict(
+ changed=False,
+ state=dict()
+ )
+
+ self.resource_group = None
+ self.name = None
+ self.state = None
+ self.location = None
+ self.storage_account = None
+ self.app_settings = None
+ self.plan = None
+ self.container_settings = None
+
+ required_if = [('state', 'present', ['storage_account'])]
+
+ super(AzureRMFunctionApp, self).__init__(
+ self.module_arg_spec,
+ supports_check_mode=True,
+ required_if=required_if
+ )
+
+ def exec_module(self, **kwargs):
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+ if self.app_settings is None:
+ self.app_settings = dict()
+
+ try:
+ resource_group = self.rm_client.resource_groups.get(self.resource_group)
+ except CloudError:
+ self.fail('Unable to retrieve resource group')
+
+ self.location = self.location or resource_group.location
+
+        function_app = None  # ensure the name is bound even if the GET below raises
+        try:
+ function_app = self.web_client.web_apps.get(
+ resource_group_name=self.resource_group,
+ name=self.name
+ )
+ # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError
+ exists = function_app is not None
+ except CloudError as exc:
+ exists = False
+
+ if self.state == 'absent':
+ if exists:
+ if self.check_mode:
+ self.results['changed'] = True
+ return self.results
+ try:
+ self.web_client.web_apps.delete(
+ resource_group_name=self.resource_group,
+ name=self.name
+ )
+ self.results['changed'] = True
+ except CloudError as exc:
+ self.fail('Failure while deleting web app: {0}'.format(exc))
+ else:
+ self.results['changed'] = False
+ else:
+ kind = 'functionapp'
+ linux_fx_version = None
+ if self.container_settings and self.container_settings.get('name'):
+ kind = 'functionapp,linux,container'
+ linux_fx_version = 'DOCKER|'
+ if self.container_settings.get('registry_server_url'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url']
+ linux_fx_version += self.container_settings['registry_server_url'] + '/'
+ linux_fx_version += self.container_settings['name']
+ if self.container_settings.get('registry_server_user'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings.get('registry_server_user')
+
+ if self.container_settings.get('registry_server_password'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings.get('registry_server_password')
+
+ if not self.plan and function_app:
+ self.plan = function_app.server_farm_id
+
+ if not exists:
+ function_app = Site(
+ location=self.location,
+ kind=kind,
+ site_config=SiteConfig(
+ app_settings=self.aggregated_app_settings(),
+ scm_type='LocalGit'
+ )
+ )
+ self.results['changed'] = True
+ else:
+ self.results['changed'], function_app = self.update(function_app)
+
+ # get app service plan
+ if self.plan:
+ if isinstance(self.plan, dict):
+ self.plan = "/subscriptions/{0}/resourceGroups/{1}/providers/Microsoft.Web/serverfarms/{2}".format(
+ self.subscription_id,
+ self.plan.get('resource_group', self.resource_group),
+ self.plan.get('name')
+ )
+ function_app.server_farm_id = self.plan
+
+ # set linux fx version
+ if linux_fx_version:
+ function_app.site_config.linux_fx_version = linux_fx_version
+
+ if self.check_mode:
+ self.results['state'] = function_app.as_dict()
+ elif self.results['changed']:
+ try:
+ new_function_app = self.web_client.web_apps.create_or_update(
+ resource_group_name=self.resource_group,
+ name=self.name,
+ site_envelope=function_app
+ ).result()
+ self.results['state'] = new_function_app.as_dict()
+ except CloudError as exc:
+ self.fail('Error creating or updating web app: {0}'.format(exc))
+
+ return self.results
+
+ def update(self, source_function_app):
+ """Update the Site object if there are any changes"""
+
+ source_app_settings = self.web_client.web_apps.list_application_settings(
+ resource_group_name=self.resource_group,
+ name=self.name
+ )
+
+ changed, target_app_settings = self.update_app_settings(source_app_settings.properties)
+
+ source_function_app.site_config = SiteConfig(
+ app_settings=target_app_settings,
+ scm_type='LocalGit'
+ )
+
+ return changed, source_function_app
+
+ def update_app_settings(self, source_app_settings):
+ """Update app settings"""
+
+ target_app_settings = self.aggregated_app_settings()
+ target_app_settings_dict = dict([(i.name, i.value) for i in target_app_settings])
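+        # changed iff the desired (aggregated) settings differ from what is stored on the app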
+ return target_app_settings_dict != source_app_settings, target_app_settings
+
+ def necessary_functionapp_settings(self):
+ """Construct the necessary app settings required for an Azure Function App"""
+
+ function_app_settings = []
+
+ if self.container_settings is None:
+ for key in ['AzureWebJobsStorage', 'WEBSITE_CONTENTAZUREFILECONNECTIONSTRING', 'AzureWebJobsDashboard']:
+ function_app_settings.append(NameValuePair(name=key, value=self.storage_connection_string))
+ function_app_settings.append(NameValuePair(name='FUNCTIONS_EXTENSION_VERSION', value='~1'))
+ function_app_settings.append(NameValuePair(name='WEBSITE_NODE_DEFAULT_VERSION', value='6.5.0'))
+ function_app_settings.append(NameValuePair(name='WEBSITE_CONTENTSHARE', value=self.name))
+ else:
+ function_app_settings.append(NameValuePair(name='FUNCTIONS_EXTENSION_VERSION', value='~2'))
+ function_app_settings.append(NameValuePair(name='WEBSITES_ENABLE_APP_SERVICE_STORAGE', value=False))
+ function_app_settings.append(NameValuePair(name='AzureWebJobsStorage', value=self.storage_connection_string))
+
+ return function_app_settings
+
+ def aggregated_app_settings(self):
+ """Combine both system and user app settings"""
+
+ function_app_settings = self.necessary_functionapp_settings()
+ for app_setting_key in self.app_settings:
+ found_setting = None
+ for s in function_app_settings:
+ if s.name == app_setting_key:
+ found_setting = s
+ break
+ if found_setting:
+ found_setting.value = self.app_settings[app_setting_key]
+ else:
+ function_app_settings.append(NameValuePair(
+ name=app_setting_key,
+ value=self.app_settings[app_setting_key]
+ ))
+ return function_app_settings
+
+ @property
+ def storage_connection_string(self):
+ """Construct the storage account connection string"""
+
+ return 'DefaultEndpointsProtocol=https;AccountName={0};AccountKey={1}'.format(
+ self.storage_account,
+ self.storage_key
+ )
+
+ @property
+ def storage_key(self):
+ """Retrieve the storage account key"""
+
+ return self.storage_client.storage_accounts.list_keys(
+ resource_group_name=self.resource_group,
+ account_name=self.storage_account
+ ).keys[0].value
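+    # A sketch of what the two properties above produce (hypothetical account name):
+    # 'DefaultEndpointsProtocol=https;AccountName=mystorage;AccountKey=<key from list_keys>'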
+
+
+def main():
+ """Main function execution"""
+
+ AzureRMFunctionApp()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_functionapp_info.py b/test/support/integration/plugins/modules/azure_rm_functionapp_info.py
new file mode 100644
index 00000000..40672f95
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_functionapp_info.py
@@ -0,0 +1,207 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2016 Thomas Stringer, <tomstr@microsoft.com>
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_functionapp_info
+version_added: "2.9"
+short_description: Get Azure Function App facts
+description:
+ - Get facts for one Azure Function App or all Function Apps within a resource group.
+options:
+ name:
+ description:
+ - Only show results for a specific Function App.
+ resource_group:
+ description:
+ - Limit results to a resource group. Required when filtering by name.
+ aliases:
+ - resource_group_name
+ tags:
+ description:
+ - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'.
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Thomas Stringer (@trstringer)
+'''
+
+EXAMPLES = '''
+ - name: Get facts for one Function App
+ azure_rm_functionapp_info:
+ resource_group: myResourceGroup
+ name: myfunctionapp
+
+ - name: Get facts for all Function Apps in a resource group
+ azure_rm_functionapp_info:
+ resource_group: myResourceGroup
+
+ - name: Get facts for all Function Apps by tags
+ azure_rm_functionapp_info:
+ tags:
+ - testing
+'''
+
+RETURN = '''
+azure_functionapps:
+ description:
+ - List of Azure Function Apps dicts.
+ returned: always
+ type: list
+ example:
+ id: /subscriptions/.../resourceGroups/ansible-rg/providers/Microsoft.Web/sites/myfunctionapp
+ name: myfunctionapp
+ kind: functionapp
+ location: East US
+ type: Microsoft.Web/sites
+ state: Running
+ host_names:
+ - myfunctionapp.azurewebsites.net
+ repository_site_name: myfunctionapp
+ usage_state: Normal
+ enabled: true
+ enabled_host_names:
+ - myfunctionapp.azurewebsites.net
+ - myfunctionapp.scm.azurewebsites.net
+ availability_state: Normal
+ host_name_ssl_states:
+ - name: myfunctionapp.azurewebsites.net
+ ssl_state: Disabled
+ host_type: Standard
+ - name: myfunctionapp.scm.azurewebsites.net
+ ssl_state: Disabled
+ host_type: Repository
+ server_farm_id: /subscriptions/.../resourceGroups/ansible-rg/providers/Microsoft.Web/serverfarms/EastUSPlan
+ reserved: false
+ last_modified_time_utc: 2017-08-22T18:54:01.190Z
+ scm_site_also_stopped: false
+ client_affinity_enabled: true
+ client_cert_enabled: false
+ host_names_disabled: false
+ outbound_ip_addresses: ............
+ container_size: 1536
+ daily_memory_time_quota: 0
+ resource_group: myResourceGroup
+ default_host_name: myfunctionapp.azurewebsites.net
+'''
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+except Exception:
+ # This is handled in azure_rm_common
+ pass
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+
+class AzureRMFunctionAppInfo(AzureRMModuleBase):
+ def __init__(self):
+
+ self.module_arg_spec = dict(
+ name=dict(type='str'),
+ resource_group=dict(type='str', aliases=['resource_group_name']),
+ tags=dict(type='list'),
+ )
+
+ self.results = dict(
+ changed=False,
+ ansible_info=dict(azure_functionapps=[])
+ )
+
+ self.name = None
+ self.resource_group = None
+ self.tags = None
+
+ super(AzureRMFunctionAppInfo, self).__init__(
+ self.module_arg_spec,
+ supports_tags=False,
+ facts_module=True
+ )
+
+ def exec_module(self, **kwargs):
+
+ is_old_facts = self.module._name == 'azure_rm_functionapp_facts'
+ if is_old_facts:
+ self.module.deprecate("The 'azure_rm_functionapp_facts' module has been renamed to 'azure_rm_functionapp_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+
+ if self.name and not self.resource_group:
+ self.fail("Parameter error: resource group required when filtering by name.")
+
+ if self.name:
+ self.results['ansible_info']['azure_functionapps'] = self.get_functionapp()
+ elif self.resource_group:
+ self.results['ansible_info']['azure_functionapps'] = self.list_resource_group()
+ else:
+ self.results['ansible_info']['azure_functionapps'] = self.list_all()
+
+ return self.results
+
+ def get_functionapp(self):
+ self.log('Get properties for Function App {0}'.format(self.name))
+ function_app = None
+ result = []
+
+ try:
+ function_app = self.web_client.web_apps.get(
+ self.resource_group,
+ self.name
+ )
+ except CloudError:
+ pass
+
+ if function_app and self.has_tags(function_app.tags, self.tags):
+ result = function_app.as_dict()
+
+ return [result]
+
+ def list_resource_group(self):
+ self.log('List items')
+ try:
+ response = self.web_client.web_apps.list_by_resource_group(self.resource_group)
+ except Exception as exc:
+ self.fail("Error listing for resource group {0} - {1}".format(self.resource_group, str(exc)))
+
+ results = []
+ for item in response:
+ if self.has_tags(item.tags, self.tags):
+ results.append(item.as_dict())
+ return results
+
+ def list_all(self):
+ self.log('List all items')
+ try:
+            # list() returns all function apps in the subscription; resource_group is None here
+            response = self.web_client.web_apps.list()
+ except Exception as exc:
+ self.fail("Error listing all items - {0}".format(str(exc)))
+
+ results = []
+ for item in response:
+ if self.has_tags(item.tags, self.tags):
+ results.append(item.as_dict())
+ return results
+
+
+def main():
+ AzureRMFunctionAppInfo()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py
new file mode 100644
index 00000000..212cf795
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py
@@ -0,0 +1,241 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2019 Zim Kalinowski, (@zikalino)
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbconfiguration
+version_added: "2.8"
+short_description: Manage Configuration instance
+description:
+ - Create, update and delete instance of Configuration.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group that contains the resource.
+ required: True
+ server_name:
+ description:
+ - The name of the server.
+ required: True
+ name:
+ description:
+ - The name of the server configuration.
+ required: True
+ value:
+ description:
+ - Value of the configuration.
+ state:
+ description:
+ - Assert the state of the MariaDB configuration. Use C(present) to update setting, or C(absent) to reset to default value.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+'''
+
+EXAMPLES = '''
+  - name: Update MariaDB Server setting
+ azure_rm_mariadbconfiguration:
+ resource_group: myResourceGroup
+ server_name: myServer
+ name: event_scheduler
+ value: "ON"
+'''
+
+RETURN = '''
+id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+    sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/myServer/configurations/event_scheduler"
+'''
+
+import time
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+    from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class Actions:
+ NoAction, Create, Update, Delete = range(4)
+
+
+class AzureRMMariaDbConfiguration(AzureRMModuleBase):
+
+ def __init__(self):
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ server_name=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str',
+ required=True
+ ),
+ value=dict(
+ type='str'
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+
+ self.resource_group = None
+ self.server_name = None
+ self.name = None
+ self.value = None
+
+ self.results = dict(changed=False)
+ self.state = None
+ self.to_do = Actions.NoAction
+
+ super(AzureRMMariaDbConfiguration, self).__init__(derived_arg_spec=self.module_arg_spec,
+ supports_check_mode=True,
+ supports_tags=False)
+
+ def exec_module(self, **kwargs):
+
+ for key in list(self.module_arg_spec.keys()):
+ if hasattr(self, key):
+ setattr(self, key, kwargs[key])
+
+ old_response = None
+ response = None
+
+ old_response = self.get_configuration()
+
+ if not old_response:
+ self.log("Configuration instance doesn't exist")
+ if self.state == 'absent':
+ self.log("Old instance didn't exist")
+ else:
+ self.to_do = Actions.Create
+ else:
+ self.log("Configuration instance already exists")
+ if self.state == 'absent' and old_response['source'] == 'user-override':
+ self.to_do = Actions.Delete
+ elif self.state == 'present':
+ self.log("Need to check if Configuration instance has to be deleted or may be updated")
+ if self.value != old_response.get('value'):
+ self.to_do = Actions.Update
+
+ if (self.to_do == Actions.Create) or (self.to_do == Actions.Update):
+ self.log("Need to Create / Update the Configuration instance")
+
+ if self.check_mode:
+ self.results['changed'] = True
+ return self.results
+
+ response = self.create_update_configuration()
+
+ self.results['changed'] = True
+ self.log("Creation / Update done")
+ elif self.to_do == Actions.Delete:
+ self.log("Configuration instance deleted")
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ self.delete_configuration()
+ else:
+ self.log("Configuration instance unchanged")
+ self.results['changed'] = False
+ response = old_response
+
+ if response:
+ self.results["id"] = response["id"]
+
+ return self.results
+
+ def create_update_configuration(self):
+ self.log("Creating / Updating the Configuration instance {0}".format(self.name))
+
+ try:
+ response = self.mariadb_client.configurations.create_or_update(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ configuration_name=self.name,
+ value=self.value,
+ source='user-override')
+ if isinstance(response, LROPoller):
+ response = self.get_poller_result(response)
+
+ except CloudError as exc:
+ self.log('Error attempting to create the Configuration instance.')
+ self.fail("Error creating the Configuration instance: {0}".format(str(exc)))
+ return response.as_dict()
+
+ def delete_configuration(self):
+ self.log("Deleting the Configuration instance {0}".format(self.name))
+ try:
+ response = self.mariadb_client.configurations.create_or_update(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ configuration_name=self.name,
+ source='system-default')
+ except CloudError as e:
+ self.log('Error attempting to delete the Configuration instance.')
+ self.fail("Error deleting the Configuration instance: {0}".format(str(e)))
+
+ return True
+
+ def get_configuration(self):
+ self.log("Checking if the Configuration instance {0} is present".format(self.name))
+ found = False
+ try:
+ response = self.mariadb_client.configurations.get(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ configuration_name=self.name)
+ found = True
+ self.log("Response : {0}".format(response))
+ self.log("Configuration instance : {0} found".format(response.name))
+ except CloudError as e:
+ self.log('Did not find the Configuration instance.')
+ if found is True:
+ return response.as_dict()
+
+ return False
+
+
+def main():
+ """Main execution"""
+ AzureRMMariaDbConfiguration()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py
new file mode 100644
index 00000000..3faac5eb
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py
@@ -0,0 +1,217 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2019 Zim Kalinowski, (@zikalino)
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbconfiguration_info
+version_added: "2.9"
+short_description: Get Azure MariaDB Configuration facts
+description:
+ - Get facts of Azure MariaDB Configuration.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal.
+ required: True
+ type: str
+ server_name:
+ description:
+ - The name of the server.
+ required: True
+ type: str
+ name:
+ description:
+ - Setting name.
+ type: str
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+
+'''
+
+EXAMPLES = '''
+ - name: Get specific setting of MariaDB Server
+ azure_rm_mariadbconfiguration_info:
+ resource_group: myResourceGroup
+ server_name: testserver
+ name: deadlock_timeout
+
+ - name: Get all settings of MariaDB Server
+ azure_rm_mariadbconfiguration_info:
+ resource_group: myResourceGroup
+ server_name: server_name
+'''
+
+RETURN = '''
+settings:
+ description:
+ - A list of dictionaries containing MariaDB Server settings.
+ returned: always
+ type: complex
+ contains:
+ id:
+ description:
+ - Setting resource ID.
+ returned: always
+ type: str
+            sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/configurations/deadlock_timeout"
+ name:
+ description:
+ - Setting name.
+ returned: always
+ type: str
+ sample: deadlock_timeout
+ value:
+ description:
+ - Setting value.
+ returned: always
+ type: raw
+ sample: 1000
+ description:
+ description:
+ - Description of the configuration.
+ returned: always
+ type: str
+ sample: Deadlock timeout.
+ source:
+ description:
+ - Source of the configuration.
+ returned: always
+ type: str
+ sample: system-default
+'''
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrestazure.azure_operation import AzureOperationPoller
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class AzureRMMariaDbConfigurationInfo(AzureRMModuleBase):
+ def __init__(self):
+ # define user inputs into argument
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ server_name=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str'
+ )
+ )
+ # store the results of the module operation
+ self.results = dict(changed=False)
+ self.mgmt_client = None
+ self.resource_group = None
+ self.server_name = None
+ self.name = None
+ super(AzureRMMariaDbConfigurationInfo, self).__init__(self.module_arg_spec, supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ is_old_facts = self.module._name == 'azure_rm_mariadbconfiguration_facts'
+ if is_old_facts:
+ self.module.deprecate("The 'azure_rm_mariadbconfiguration_facts' module has been renamed to 'azure_rm_mariadbconfiguration_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+ self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+
+ if self.name is not None:
+ self.results['settings'] = self.get()
+ else:
+ self.results['settings'] = self.list_by_server()
+ return self.results
+
+ def get(self):
+ '''
+ Gets facts of the specified MariaDB Configuration.
+
+        :return: deserialized MariaDB Configuration instance state dictionary
+ '''
+ response = None
+ results = []
+ try:
+ response = self.mgmt_client.configurations.get(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ configuration_name=self.name)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.log('Could not get facts for Configurations.')
+
+ if response is not None:
+ results.append(self.format_item(response))
+
+ return results
+
+ def list_by_server(self):
+ '''
+        Gets facts of all configurations of the specified MariaDB server.
+
+        :return: list of deserialized MariaDB Configuration instance state dictionaries
+ '''
+ response = None
+ results = []
+ try:
+ response = self.mgmt_client.configurations.list_by_server(resource_group_name=self.resource_group,
+ server_name=self.server_name)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.log('Could not get facts for Configurations.')
+
+ if response is not None:
+ for item in response:
+ results.append(self.format_item(item))
+
+ return results
+
+ def format_item(self, item):
+ d = item.as_dict()
+ d = {
+ 'resource_group': self.resource_group,
+ 'server_name': self.server_name,
+ 'id': d['id'],
+ 'name': d['name'],
+ 'value': d['value'],
+ 'description': d['description'],
+ 'source': d['source']
+ }
+ return d
+
+
+def main():
+ AzureRMMariaDbConfigurationInfo()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py
new file mode 100644
index 00000000..8492b968
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py
@@ -0,0 +1,304 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com>
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbdatabase
+version_added: "2.8"
+short_description: Manage MariaDB Database instance
+description:
+ - Create, update and delete instance of MariaDB Database.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal.
+ required: True
+ server_name:
+ description:
+ - The name of the server.
+ required: True
+ name:
+ description:
+ - The name of the database.
+ required: True
+ charset:
+ description:
+ - The charset of the database. Check MariaDB documentation for possible values.
+ - This is only set on creation, use I(force_update) to recreate a database if the values don't match.
+ collation:
+ description:
+ - The collation of the database. Check MariaDB documentation for possible values.
+ - This is only set on creation, use I(force_update) to recreate a database if the values don't match.
+ force_update:
+ description:
+ - When set to C(true), will delete and recreate the existing MariaDB database if any of the properties don't match what is set.
+ - When set to C(false), no change will occur to the database even if any of the properties do not match.
+ type: bool
+ default: 'no'
+ state:
+ description:
+ - Assert the state of the MariaDB Database. Use C(present) to create or update a database and C(absent) to delete it.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+
+'''
+
+EXAMPLES = '''
+ - name: Create (or update) MariaDB Database
+ azure_rm_mariadbdatabase:
+ resource_group: myResourceGroup
+ server_name: testserver
+ name: db1
+'''
+
+RETURN = '''
+id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+ sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/databases/db1
+name:
+ description:
+ - Resource name.
+ returned: always
+ type: str
+ sample: db1
+'''
+
+import time
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class Actions:
+ NoAction, Create, Update, Delete = range(4)
+
+
+class AzureRMMariaDbDatabase(AzureRMModuleBase):
+ """Configuration class for an Azure RM MariaDB Database resource"""
+
+ def __init__(self):
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ server_name=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str',
+ required=True
+ ),
+ charset=dict(
+ type='str'
+ ),
+ collation=dict(
+ type='str'
+ ),
+ force_update=dict(
+ type='bool',
+ default=False
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+
+ self.resource_group = None
+ self.server_name = None
+ self.name = None
+ self.force_update = None
+ self.parameters = dict()
+
+ self.results = dict(changed=False)
+ self.mgmt_client = None
+ self.state = None
+ self.to_do = Actions.NoAction
+
+ super(AzureRMMariaDbDatabase, self).__init__(derived_arg_spec=self.module_arg_spec,
+ supports_check_mode=True,
+ supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ """Main module execution method"""
+
+ for key in list(self.module_arg_spec.keys()):
+ if hasattr(self, key):
+ setattr(self, key, kwargs[key])
+ elif kwargs[key] is not None:
+ if key == "charset":
+ self.parameters["charset"] = kwargs[key]
+ elif key == "collation":
+ self.parameters["collation"] = kwargs[key]
+
+ old_response = None
+ response = None
+
+ self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+
+ resource_group = self.get_resource_group(self.resource_group)
+
+ old_response = self.get_mariadbdatabase()
+
+ if not old_response:
+ self.log("MariaDB Database instance doesn't exist")
+ if self.state == 'absent':
+ self.log("Old instance didn't exist")
+ else:
+ self.to_do = Actions.Create
+ else:
+ self.log("MariaDB Database instance already exists")
+ if self.state == 'absent':
+ self.to_do = Actions.Delete
+ elif self.state == 'present':
+ self.log("Need to check if MariaDB Database instance has to be deleted or may be updated")
+ if ('collation' in self.parameters) and (self.parameters['collation'] != old_response['collation']):
+ self.to_do = Actions.Update
+ if ('charset' in self.parameters) and (self.parameters['charset'] != old_response['charset']):
+ self.to_do = Actions.Update
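+                # charset/collation can only be set at creation time, so an "update"
+                # here really means delete + recreate, which is gated behind force_update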
+ if self.to_do == Actions.Update:
+ if self.force_update:
+ if not self.check_mode:
+ self.delete_mariadbdatabase()
+ else:
+ self.fail("Database properties cannot be updated without setting 'force_update' option")
+ self.to_do = Actions.NoAction
+
+ if (self.to_do == Actions.Create) or (self.to_do == Actions.Update):
+ self.log("Need to Create / Update the MariaDB Database instance")
+
+ if self.check_mode:
+ self.results['changed'] = True
+ return self.results
+
+ response = self.create_update_mariadbdatabase()
+ self.results['changed'] = True
+ self.log("Creation / Update done")
+ elif self.to_do == Actions.Delete:
+ self.log("MariaDB Database instance deleted")
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ self.delete_mariadbdatabase()
+ # make sure instance is actually deleted, for some Azure resources, instance is hanging around
+ # for some time after deletion -- this should be really fixed in Azure
+ while self.get_mariadbdatabase():
+ time.sleep(20)
+ else:
+ self.log("MariaDB Database instance unchanged")
+ self.results['changed'] = False
+ response = old_response
+
+ if response:
+ self.results["id"] = response["id"]
+ self.results["name"] = response["name"]
+
+ return self.results
+
+ def create_update_mariadbdatabase(self):
+ '''
+ Creates or updates MariaDB Database with the specified configuration.
+
+ :return: deserialized MariaDB Database instance state dictionary
+ '''
+ self.log("Creating / Updating the MariaDB Database instance {0}".format(self.name))
+
+ try:
+ response = self.mgmt_client.databases.create_or_update(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ database_name=self.name,
+ parameters=self.parameters)
+ if isinstance(response, LROPoller):
+ response = self.get_poller_result(response)
+
+ except CloudError as exc:
+ self.log('Error attempting to create the MariaDB Database instance.')
+ self.fail("Error creating the MariaDB Database instance: {0}".format(str(exc)))
+ return response.as_dict()
+
+ def delete_mariadbdatabase(self):
+ '''
+ Deletes specified MariaDB Database instance in the specified subscription and resource group.
+
+ :return: True
+ '''
+ self.log("Deleting the MariaDB Database instance {0}".format(self.name))
+ try:
+ response = self.mgmt_client.databases.delete(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ database_name=self.name)
+ except CloudError as e:
+ self.log('Error attempting to delete the MariaDB Database instance.')
+ self.fail("Error deleting the MariaDB Database instance: {0}".format(str(e)))
+
+ return True
+
+ def get_mariadbdatabase(self):
+ '''
+ Gets the properties of the specified MariaDB Database.
+
+ :return: deserialized MariaDB Database instance state dictionary
+ '''
+ self.log("Checking if the MariaDB Database instance {0} is present".format(self.name))
+ found = False
+ try:
+ response = self.mgmt_client.databases.get(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ database_name=self.name)
+ found = True
+ self.log("Response : {0}".format(response))
+ self.log("MariaDB Database instance : {0} found".format(response.name))
+ except CloudError as e:
+ self.log('Did not find the MariaDB Database instance.')
+ if found is True:
+ return response.as_dict()
+
+ return False
+
+
+def main():
+ """Main execution"""
+ AzureRMMariaDbDatabase()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py
new file mode 100644
index 00000000..e9c99c14
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py
@@ -0,0 +1,212 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com>
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbdatabase_info
+version_added: "2.9"
+short_description: Get Azure MariaDB Database facts
+description:
+ - Get facts of MariaDB Database.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal.
+ required: True
+ type: str
+ server_name:
+ description:
+ - The name of the server.
+ required: True
+ type: str
+ name:
+ description:
+ - The name of the database.
+ type: str
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+
+'''
+
+EXAMPLES = '''
+ - name: Get instance of MariaDB Database
+ azure_rm_mariadbdatabase_info:
+ resource_group: myResourceGroup
+ server_name: server_name
+ name: database_name
+
+ - name: List instances of MariaDB Database
+ azure_rm_mariadbdatabase_info:
+ resource_group: myResourceGroup
+ server_name: server_name
+'''
+
+RETURN = '''
+databases:
+ description:
+ - A list of dictionaries containing facts for MariaDB Databases.
+ returned: always
+ type: complex
+ contains:
+ id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+            sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/databases/db1"
+ resource_group:
+ description:
+ - Resource group name.
+ returned: always
+ type: str
+ sample: testrg
+ server_name:
+ description:
+ - Server name.
+ returned: always
+ type: str
+ sample: testserver
+ name:
+ description:
+ - Resource name.
+ returned: always
+ type: str
+ sample: db1
+ charset:
+ description:
+ - The charset of the database.
+ returned: always
+ type: str
+ sample: UTF8
+ collation:
+ description:
+ - The collation of the database.
+ returned: always
+ type: str
+ sample: English_United States.1252
+'''
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class AzureRMMariaDbDatabaseInfo(AzureRMModuleBase):
+ def __init__(self):
+ # define user inputs into argument
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ server_name=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str'
+ )
+ )
+ # store the results of the module operation
+ self.results = dict(
+ changed=False
+ )
+ self.resource_group = None
+ self.server_name = None
+ self.name = None
+ super(AzureRMMariaDbDatabaseInfo, self).__init__(self.module_arg_spec, supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ is_old_facts = self.module._name == 'azure_rm_mariadbdatabase_facts'
+ if is_old_facts:
+ self.module.deprecate("The 'azure_rm_mariadbdatabase_facts' module has been renamed to 'azure_rm_mariadbdatabase_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+
+ if (self.resource_group is not None and
+ self.server_name is not None and
+ self.name is not None):
+ self.results['databases'] = self.get()
+ elif (self.resource_group is not None and
+ self.server_name is not None):
+ self.results['databases'] = self.list_by_server()
+ return self.results
+
+ def get(self):
+ response = None
+ results = []
+ try:
+ response = self.mariadb_client.databases.get(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ database_name=self.name)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.log('Could not get facts for Databases.')
+
+ if response is not None:
+ results.append(self.format_item(response))
+
+ return results
+
+ def list_by_server(self):
+ response = None
+ results = []
+ try:
+ response = self.mariadb_client.databases.list_by_server(resource_group_name=self.resource_group,
+ server_name=self.server_name)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.fail("Error listing for server {0} - {1}".format(self.server_name, str(e)))
+
+ if response is not None:
+ for item in response:
+ results.append(self.format_item(item))
+
+ return results
+
+ def format_item(self, item):
+ d = item.as_dict()
+ d = {
+ 'resource_group': self.resource_group,
+ 'server_name': self.server_name,
+ 'name': d['name'],
+ 'charset': d['charset'],
+ 'collation': d['collation']
+ }
+ return d
+
+
+def main():
+ AzureRMMariaDbDatabaseInfo()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py
new file mode 100644
index 00000000..1fc8c5e7
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py
@@ -0,0 +1,277 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com>
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbfirewallrule
+version_added: "2.8"
+short_description: Manage MariaDB firewall rule instance
+description:
+ - Create, update and delete instance of MariaDB firewall rule.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal.
+ required: True
+ server_name:
+ description:
+ - The name of the server.
+ required: True
+ name:
+ description:
+ - The name of the MariaDB firewall rule.
+ required: True
+ start_ip_address:
+ description:
+ - The start IP address of the MariaDB firewall rule. Must be IPv4 format.
+ end_ip_address:
+ description:
+ - The end IP address of the MariaDB firewall rule. Must be IPv4 format.
+ state:
+ description:
+ - Assert the state of the MariaDB firewall rule. Use C(present) to create or update a rule and C(absent) to ensure it is not present.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+
+'''
+
+EXAMPLES = '''
+ - name: Create (or update) MariaDB firewall rule
+ azure_rm_mariadbfirewallrule:
+ resource_group: myResourceGroup
+ server_name: testserver
+ name: rule1
+ start_ip_address: 10.0.0.17
+ end_ip_address: 10.0.0.20
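+
+ # A minimal sketch, assuming the rule created above: setting state to absent
+ # removes the rule again.
+ - name: Delete MariaDB firewall rule
+ azure_rm_mariadbfirewallrule:
+ resource_group: myResourceGroup
+ server_name: testserver
+ name: rule1
+ state: absent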
+'''
+
+RETURN = '''
+id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+ sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/fire
+ wallRules/rule1"
+'''
+
+import time
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class Actions:
+ NoAction, Create, Update, Delete = range(4)
+
+
+class AzureRMMariaDbFirewallRule(AzureRMModuleBase):
+ """Configuration class for an Azure RM MariaDB firewall rule resource"""
+
+ def __init__(self):
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ server_name=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str',
+ required=True
+ ),
+ start_ip_address=dict(
+ type='str'
+ ),
+ end_ip_address=dict(
+ type='str'
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+
+ self.resource_group = None
+ self.server_name = None
+ self.name = None
+ self.start_ip_address = None
+ self.end_ip_address = None
+
+ self.results = dict(changed=False)
+ self.state = None
+ self.to_do = Actions.NoAction
+
+ super(AzureRMMariaDbFirewallRule, self).__init__(derived_arg_spec=self.module_arg_spec,
+ supports_check_mode=True,
+ supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ """Main module execution method"""
+
+ for key in list(self.module_arg_spec.keys()):
+ if hasattr(self, key):
+ setattr(self, key, kwargs[key])
+
+ old_response = None
+ response = None
+
+ resource_group = self.get_resource_group(self.resource_group)
+
+ old_response = self.get_firewallrule()
+
+ if not old_response:
+ self.log("MariaDB firewall rule instance doesn't exist")
+ if self.state == 'absent':
+ self.log("Old instance didn't exist")
+ else:
+ self.to_do = Actions.Create
+ else:
+ self.log("MariaDB firewall rule instance already exists")
+ if self.state == 'absent':
+ self.to_do = Actions.Delete
+ elif self.state == 'present':
+ self.log("Need to check if MariaDB firewall rule instance has to be deleted or may be updated")
+ if (self.start_ip_address is not None) and (self.start_ip_address != old_response['start_ip_address']):
+ self.to_do = Actions.Update
+ if (self.end_ip_address is not None) and (self.end_ip_address != old_response['end_ip_address']):
+ self.to_do = Actions.Update
+
+ if (self.to_do == Actions.Create) or (self.to_do == Actions.Update):
+ self.log("Need to Create / Update the MariaDB firewall rule instance")
+
+ if self.check_mode:
+ self.results['changed'] = True
+ return self.results
+
+ response = self.create_update_firewallrule()
+
+ if not old_response:
+ self.results['changed'] = True
+ else:
+ self.results['changed'] = (old_response != response)
+ self.log("Creation / Update done")
+ elif self.to_do == Actions.Delete:
+ self.log("MariaDB firewall rule instance deleted")
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ self.delete_firewallrule()
+ # make sure instance is actually deleted, for some Azure resources, instance is hanging around
+ # for some time after deletion -- this should be really fixed in Azure
+ while self.get_firewallrule():
+ time.sleep(20)
+ else:
+ self.log("MariaDB firewall rule instance unchanged")
+ self.results['changed'] = False
+ response = old_response
+
+ if response:
+ self.results["id"] = response["id"]
+
+ return self.results
+
+ def create_update_firewallrule(self):
+ '''
+ Creates or updates MariaDB firewall rule with the specified configuration.
+
+ :return: deserialized MariaDB firewall rule instance state dictionary
+ '''
+ self.log("Creating / Updating the MariaDB firewall rule instance {0}".format(self.name))
+
+ try:
+ response = self.mariadb_client.firewall_rules.create_or_update(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ firewall_rule_name=self.name,
+ start_ip_address=self.start_ip_address,
+ end_ip_address=self.end_ip_address)
+ if isinstance(response, LROPoller):
+ response = self.get_poller_result(response)
+
+ except CloudError as exc:
+ self.log('Error attempting to create the MariaDB firewall rule instance.')
+ self.fail("Error creating the MariaDB firewall rule instance: {0}".format(str(exc)))
+ return response.as_dict()
+
+ def delete_firewallrule(self):
+ '''
+ Deletes specified MariaDB firewall rule instance in the specified subscription and resource group.
+
+ :return: True
+ '''
+ self.log("Deleting the MariaDB firewall rule instance {0}".format(self.name))
+ try:
+ response = self.mariadb_client.firewall_rules.delete(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ firewall_rule_name=self.name)
+ except CloudError as e:
+ self.log('Error attempting to delete the MariaDB firewall rule instance.')
+ self.fail("Error deleting the MariaDB firewall rule instance: {0}".format(str(e)))
+
+ return True
+
+ def get_firewallrule(self):
+ '''
+ Gets the properties of the specified MariaDB firewall rule.
+
+ :return: deserialized MariaDB firewall rule instance state dictionary
+ '''
+ self.log("Checking if the MariaDB firewall rule instance {0} is present".format(self.name))
+ found = False
+ try:
+ response = self.mariadb_client.firewall_rules.get(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ firewall_rule_name=self.name)
+ found = True
+ self.log("Response : {0}".format(response))
+ self.log("MariaDB firewall rule instance : {0} found".format(response.name))
+ except CloudError as e:
+ self.log('Did not find the MariaDB firewall rule instance.')
+ if found is True:
+ return response.as_dict()
+
+ return False
+
+
+def main():
+ """Main execution"""
+ AzureRMMariaDbFirewallRule()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py
new file mode 100644
index 00000000..ef71be8d
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py
@@ -0,0 +1,208 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com>
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbfirewallrule_info
+version_added: "2.9"
+short_description: Get Azure MariaDB Firewall Rule facts
+description:
+ - Get facts of Azure MariaDB Firewall Rule.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group.
+ required: True
+ type: str
+ server_name:
+ description:
+ - The name of the server.
+ required: True
+ type: str
+ name:
+ description:
+ - The name of the server firewall rule.
+ type: str
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+
+'''
+
+EXAMPLES = '''
+ - name: Get instance of MariaDB Firewall Rule
+ azure_rm_mariadbfirewallrule_info:
+ resource_group: myResourceGroup
+ server_name: server_name
+ name: firewall_rule_name
+
+ - name: List instances of MariaDB Firewall Rule
+ azure_rm_mariadbfirewallrule_info:
+ resource_group: myResourceGroup
+ server_name: server_name
+'''
+
+RETURN = '''
+rules:
+ description:
+ - A list of dictionaries containing facts for MariaDB Firewall Rule.
+ returned: always
+ type: complex
+ contains:
+ id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+ sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/TestGroup/providers/Microsoft.DBforMariaDB/servers/testserver/fire
+ wallRules/rule1"
+ server_name:
+ description:
+ - The name of the server.
+ returned: always
+ type: str
+ sample: testserver
+ name:
+ description:
+ - Resource name.
+ returned: always
+ type: str
+ sample: rule1
+ start_ip_address:
+ description:
+ - The start IP address of the MariaDB firewall rule.
+ returned: always
+ type: str
+ sample: 10.0.0.16
+ end_ip_address:
+ description:
+ - The end IP address of the MariaDB firewall rule.
+ returned: always
+ type: str
+ sample: 10.0.0.18
+'''
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrestazure.azure_operation import AzureOperationPoller
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class AzureRMMariaDbFirewallRuleInfo(AzureRMModuleBase):
+ def __init__(self):
+ # define user inputs into argument
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ server_name=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str'
+ )
+ )
+ # store the results of the module operation
+ self.results = dict(
+ changed=False
+ )
+ self.mgmt_client = None
+ self.resource_group = None
+ self.server_name = None
+ self.name = None
+ super(AzureRMMariaDbFirewallRuleInfo, self).__init__(self.module_arg_spec, supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ is_old_facts = self.module._name == 'azure_rm_mariadbfirewallrule_facts'
+ if is_old_facts:
+ self.module.deprecate("The 'azure_rm_mariadbfirewallrule_facts' module has been renamed to 'azure_rm_mariadbfirewallrule_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+ self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+
+ if (self.name is not None):
+ self.results['rules'] = self.get()
+ else:
+ self.results['rules'] = self.list_by_server()
+ return self.results
+
+ def get(self):
+ response = None
+ results = []
+ try:
+ response = self.mgmt_client.firewall_rules.get(resource_group_name=self.resource_group,
+ server_name=self.server_name,
+ firewall_rule_name=self.name)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.log('Could not get facts for FirewallRules.')
+
+ if response is not None:
+ results.append(self.format_item(response))
+
+ return results
+
+ def list_by_server(self):
+ response = None
+ results = []
+ try:
+ response = self.mgmt_client.firewall_rules.list_by_server(resource_group_name=self.resource_group,
+ server_name=self.server_name)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.log('Could not get facts for FirewallRules.')
+
+ if response is not None:
+ for item in response:
+ results.append(self.format_item(item))
+
+ return results
+
+ def format_item(self, item):
+ d = item.as_dict()
+ d = {
+ 'resource_group': self.resource_group,
+ 'id': d['id'],
+ 'server_name': self.server_name,
+ 'name': d['name'],
+ 'start_ip_address': d['start_ip_address'],
+ 'end_ip_address': d['end_ip_address']
+ }
+ return d
+
+
+def main():
+ AzureRMMariaDbFirewallRuleInfo()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbserver.py b/test/support/integration/plugins/modules/azure_rm_mariadbserver.py
new file mode 100644
index 00000000..30a29988
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbserver.py
@@ -0,0 +1,388 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com>
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbserver
+version_added: "2.8"
+short_description: Manage MariaDB Server instance
+description:
+ - Create, update and delete instance of MariaDB Server.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal.
+ required: True
+ name:
+ description:
+ - The name of the server.
+ required: True
+ sku:
+ description:
+ - The SKU (pricing tier) of the server.
+ suboptions:
+ name:
+ description:
+ - The name of the SKU, typically, tier + family + cores, for example C(B_Gen4_1), C(GP_Gen5_8).
+ tier:
+ description:
+ - The tier of the particular SKU, for example C(Basic).
+ choices:
+ - basic
+ - standard
+ capacity:
+ description:
+ - The scale up/out capacity, representing the server's compute units.
+ type: int
+ size:
+ description:
+ - The size code, to be interpreted by resource as appropriate.
+ location:
+ description:
+ - Resource location. If not set, location from the resource group will be used as default.
+ storage_mb:
+ description:
+ - The maximum storage allowed for a server.
+ type: int
+ version:
+ description:
+ - Server version.
+ choices:
+ - '10.2'
+ enforce_ssl:
+ description:
+ - Enable SSL enforcement.
+ type: bool
+ default: False
+ admin_username:
+ description:
+ - The administrator's login name of a server. Can only be specified when the server is being created (and is required for creation).
+ admin_password:
+ description:
+ - The password of the administrator login.
+ create_mode:
+ description:
+ - Create mode of MariaDB Server.
+ default: Default
+ state:
+ description:
+ - Assert the state of the MariaDB Server. Use C(present) to create or update a server and C(absent) to delete it.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+ - azure_tags
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+
+'''
+
+EXAMPLES = '''
+ - name: Create (or update) MariaDB Server
+ azure_rm_mariadbserver:
+ resource_group: myResourceGroup
+ name: testserver
+ sku:
+ name: B_Gen5_1
+ tier: Basic
+ location: eastus
+ storage_mb: 1024
+ enforce_ssl: True
+ version: 10.2
+ admin_username: cloudsa
+ admin_password: password
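+
+ # A minimal sketch, assuming the server created above: setting state to absent
+ # deletes the server.
+ - name: Delete MariaDB Server
+ azure_rm_mariadbserver:
+ resource_group: myResourceGroup
+ name: testserver
+ state: absent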
+'''
+
+RETURN = '''
+id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+ sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/mariadbsrv1b6dd89593
+version:
+ description:
+ - Server version. Possible values include C(10.2).
+ returned: always
+ type: str
+ sample: 10.2
+state:
+ description:
+ - The state of the server as visible to the user. Possible values include C(Ready), C(Dropping), C(Disabled).
+ returned: always
+ type: str
+ sample: Ready
+fully_qualified_domain_name:
+ description:
+ - The fully qualified domain name of a server.
+ returned: always
+ type: str
+ sample: mariadbsrv1b6dd89593.mariadb.database.azure.com
+'''
+
+import time
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class Actions:
+ NoAction, Create, Update, Delete = range(4)
+
+
+class AzureRMMariaDbServers(AzureRMModuleBase):
+ """Configuration class for an Azure RM MariaDB Server resource"""
+
+ def __init__(self):
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str',
+ required=True
+ ),
+ sku=dict(
+ type='dict'
+ ),
+ location=dict(
+ type='str'
+ ),
+ storage_mb=dict(
+ type='int'
+ ),
+ version=dict(
+ type='str',
+ choices=['10.2']
+ ),
+ enforce_ssl=dict(
+ type='bool',
+ default=False
+ ),
+ create_mode=dict(
+ type='str',
+ default='Default'
+ ),
+ admin_username=dict(
+ type='str'
+ ),
+ admin_password=dict(
+ type='str',
+ no_log=True
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+
+ self.resource_group = None
+ self.name = None
+ self.parameters = dict()
+ self.tags = None
+
+ self.results = dict(changed=False)
+ self.state = None
+ self.to_do = Actions.NoAction
+
+ super(AzureRMMariaDbServers, self).__init__(derived_arg_spec=self.module_arg_spec,
+ supports_check_mode=True,
+ supports_tags=True)
+
+ def exec_module(self, **kwargs):
+ """Main module execution method"""
+
+ for key in list(self.module_arg_spec.keys()) + ['tags']:
+ if hasattr(self, key):
+ setattr(self, key, kwargs[key])
+ elif kwargs[key] is not None:
+ if key == "sku":
+ ev = kwargs[key]
+ if 'tier' in ev:
+ if ev['tier'] == 'basic':
+ ev['tier'] = 'Basic'
+ elif ev['tier'] == 'standard':
+ ev['tier'] = 'Standard'
+ self.parameters["sku"] = ev
+ elif key == "location":
+ self.parameters["location"] = kwargs[key]
+ elif key == "storage_mb":
+ self.parameters.setdefault("properties", {}).setdefault("storage_profile", {})["storage_mb"] = kwargs[key]
+ elif key == "version":
+ self.parameters.setdefault("properties", {})["version"] = kwargs[key]
+ elif key == "enforce_ssl":
+ self.parameters.setdefault("properties", {})["ssl_enforcement"] = 'Enabled' if kwargs[key] else 'Disabled'
+ elif key == "create_mode":
+ self.parameters.setdefault("properties", {})["create_mode"] = kwargs[key]
+ elif key == "admin_username":
+ self.parameters.setdefault("properties", {})["administrator_login"] = kwargs[key]
+ elif key == "admin_password":
+ self.parameters.setdefault("properties", {})["administrator_login_password"] = kwargs[key]
+
+ old_response = None
+ response = None
+
+ resource_group = self.get_resource_group(self.resource_group)
+
+ if "location" not in self.parameters:
+ self.parameters["location"] = resource_group.location
+
+ old_response = self.get_mariadbserver()
+
+ if not old_response:
+ self.log("MariaDB Server instance doesn't exist")
+ if self.state == 'absent':
+ self.log("Old instance didn't exist")
+ else:
+ self.to_do = Actions.Create
+ else:
+ self.log("MariaDB Server instance already exists")
+ if self.state == 'absent':
+ self.to_do = Actions.Delete
+ elif self.state == 'present':
+ self.log("Need to check if MariaDB Server instance has to be deleted or may be updated")
+ update_tags, newtags = self.update_tags(old_response.get('tags', {}))
+ if update_tags:
+ self.tags = newtags
+ self.to_do = Actions.Update
+
+ if (self.to_do == Actions.Create) or (self.to_do == Actions.Update):
+ self.log("Need to Create / Update the MariaDB Server instance")
+
+ if self.check_mode:
+ self.results['changed'] = True
+ return self.results
+
+ response = self.create_update_mariadbserver()
+
+ if not old_response:
+ self.results['changed'] = True
+ else:
+ self.results['changed'] = (old_response != response)
+ self.log("Creation / Update done")
+ elif self.to_do == Actions.Delete:
+ self.log("MariaDB Server instance deleted")
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ self.delete_mariadbserver()
+ # make sure instance is actually deleted, for some Azure resources, instance is hanging around
+ # for some time after deletion -- this should be really fixed in Azure
+ while self.get_mariadbserver():
+ time.sleep(20)
+ else:
+ self.log("MariaDB Server instance unchanged")
+ self.results['changed'] = False
+ response = old_response
+
+ if response:
+ self.results["id"] = response["id"]
+ self.results["version"] = response["version"]
+ self.results["state"] = response["user_visible_state"]
+ self.results["fully_qualified_domain_name"] = response["fully_qualified_domain_name"]
+
+ return self.results
+
+ def create_update_mariadbserver(self):
+ '''
+ Creates or updates MariaDB Server with the specified configuration.
+
+ :return: deserialized MariaDB Server instance state dictionary
+ '''
+ self.log("Creating / Updating the MariaDB Server instance {0}".format(self.name))
+
+ try:
+ self.parameters['tags'] = self.tags
+ if self.to_do == Actions.Create:
+ response = self.mariadb_client.servers.create(resource_group_name=self.resource_group,
+ server_name=self.name,
+ parameters=self.parameters)
+ else:
+ # structure of parameters for update must be changed
+ self.parameters.update(self.parameters.pop("properties", {}))
+ response = self.mariadb_client.servers.update(resource_group_name=self.resource_group,
+ server_name=self.name,
+ parameters=self.parameters)
+ if isinstance(response, LROPoller):
+ response = self.get_poller_result(response)
+
+ except CloudError as exc:
+ self.log('Error attempting to create the MariaDB Server instance.')
+ self.fail("Error creating the MariaDB Server instance: {0}".format(str(exc)))
+ return response.as_dict()
+
+ def delete_mariadbserver(self):
+ '''
+ Deletes specified MariaDB Server instance in the specified subscription and resource group.
+
+ :return: True
+ '''
+ self.log("Deleting the MariaDB Server instance {0}".format(self.name))
+ try:
+ response = self.mariadb_client.servers.delete(resource_group_name=self.resource_group,
+ server_name=self.name)
+ except CloudError as e:
+ self.log('Error attempting to delete the MariaDB Server instance.')
+ self.fail("Error deleting the MariaDB Server instance: {0}".format(str(e)))
+
+ return True
+
+ def get_mariadbserver(self):
+ '''
+ Gets the properties of the specified MariaDB Server.
+
+ :return: deserialized MariaDB Server instance state dictionary
+ '''
+ self.log("Checking if the MariaDB Server instance {0} is present".format(self.name))
+ found = False
+ try:
+ response = self.mariadb_client.servers.get(resource_group_name=self.resource_group,
+ server_name=self.name)
+ found = True
+ self.log("Response : {0}".format(response))
+ self.log("MariaDB Server instance : {0} found".format(response.name))
+ except CloudError as e:
+ self.log('Did not find the MariaDB Server instance.')
+ if found is True:
+ return response.as_dict()
+
+ return False
+
+
+def main():
+ """Main execution"""
+ AzureRMMariaDbServers()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py
new file mode 100644
index 00000000..464aa4d8
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py
@@ -0,0 +1,265 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com>
+# Copyright (c) 2019 Matti Ranta, (@techknowlogick)
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_mariadbserver_info
+version_added: "2.9"
+short_description: Get Azure MariaDB Server facts
+description:
+ - Get facts of MariaDB Server.
+
+options:
+ resource_group:
+ description:
+ - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal.
+ required: True
+ type: str
+ name:
+ description:
+ - The name of the server.
+ type: str
+ tags:
+ description:
+ - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'.
+ type: list
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+ - Matti Ranta (@techknowlogick)
+
+'''
+
+EXAMPLES = '''
+ - name: Get instance of MariaDB Server
+ azure_rm_mariadbserver_info:
+ resource_group: myResourceGroup
+ name: server_name
+
+ - name: List instances of MariaDB Server
+ azure_rm_mariadbserver_info:
+ resource_group: myResourceGroup
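+
+ # A minimal sketch of the documented tags filter; the tag names below are
+ # assumed placeholders.
+ - name: List instances of MariaDB Server filtered by tags
+ azure_rm_mariadbserver_info:
+ resource_group: myResourceGroup
+ tags:
+ - testing
+ - foo:bar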
+'''
+
+RETURN = '''
+servers:
+ description:
+ - A list of dictionaries containing facts for MariaDB servers.
+ returned: always
+ type: complex
+ contains:
+ id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+ sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/myabdud1223
+ resource_group:
+ description:
+ - Resource group name.
+ returned: always
+ type: str
+ sample: myResourceGroup
+ name:
+ description:
+ - Resource name.
+ returned: always
+ type: str
+ sample: myabdud1223
+ location:
+ description:
+ - The location the resource resides in.
+ returned: always
+ type: str
+ sample: eastus
+ sku:
+ description:
+ - The SKU of the server.
+ returned: always
+ type: complex
+ contains:
+ name:
+ description:
+ - The name of the SKU.
+ returned: always
+ type: str
+ sample: GP_Gen4_2
+ tier:
+ description:
+ - The tier of the particular SKU.
+ returned: always
+ type: str
+ sample: GeneralPurpose
+ capacity:
+ description:
+ - The scale capacity.
+ returned: always
+ type: int
+ sample: 2
+ storage_mb:
+ description:
+ - The maximum storage allowed for a server.
+ returned: always
+ type: int
+ sample: 128000
+ enforce_ssl:
+ description:
+ - Enable SSL enforcement.
+ returned: always
+ type: bool
+ sample: False
+ admin_username:
+ description:
+ - The administrator's login name of a server.
+ returned: always
+ type: str
+ sample: serveradmin
+ version:
+ description:
+ - Server version.
+ returned: always
+ type: str
+ sample: "9.6"
+ user_visible_state:
+ description:
+ - The state of the server as visible to the user.
+ returned: always
+ type: str
+ sample: Ready
+ fully_qualified_domain_name:
+ description:
+ - The fully qualified domain name of a server.
+ returned: always
+ type: str
+ sample: myabdud1223.mariadb.database.azure.com
+ tags:
+ description:
+ - Tags assigned to the resource. Dictionary of string:string pairs.
+ type: dict
+ sample: { tag1: abc }
+'''
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from azure.mgmt.rdbms.mariadb import MariaDBManagementClient
+ from msrest.serialization import Model
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class AzureRMMariaDbServerInfo(AzureRMModuleBase):
+ def __init__(self):
+ # define user inputs into argument
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str'
+ ),
+ tags=dict(
+ type='list'
+ )
+ )
+ # store the results of the module operation
+ self.results = dict(
+ changed=False
+ )
+ self.resource_group = None
+ self.name = None
+ self.tags = None
+ super(AzureRMMariaDbServerInfo, self).__init__(self.module_arg_spec, supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ is_old_facts = self.module._name == 'azure_rm_mariadbserver_facts'
+ if is_old_facts:
+ self.module.deprecate("The 'azure_rm_mariadbserver_facts' module has been renamed to 'azure_rm_mariadbserver_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+
+ if (self.resource_group is not None and
+ self.name is not None):
+ self.results['servers'] = self.get()
+ elif (self.resource_group is not None):
+ self.results['servers'] = self.list_by_resource_group()
+ return self.results
+
+ def get(self):
+ response = None
+ results = []
+ try:
+ response = self.mariadb_client.servers.get(resource_group_name=self.resource_group,
+ server_name=self.name)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.log('Could not get facts for MariaDB Server.')
+
+ if response and self.has_tags(response.tags, self.tags):
+ results.append(self.format_item(response))
+
+ return results
+
+ def list_by_resource_group(self):
+ response = None
+ results = []
+ try:
+ response = self.mariadb_client.servers.list_by_resource_group(resource_group_name=self.resource_group)
+ self.log("Response : {0}".format(response))
+ except CloudError as e:
+ self.log('Could not get facts for MariaDB Servers.')
+
+ if response is not None:
+ for item in response:
+ if self.has_tags(item.tags, self.tags):
+ results.append(self.format_item(item))
+
+ return results
+
+ def format_item(self, item):
+ d = item.as_dict()
+ d = {
+ 'id': d['id'],
+ 'resource_group': self.resource_group,
+ 'name': d['name'],
+ 'sku': d['sku'],
+ 'location': d['location'],
+ 'storage_mb': d['storage_profile']['storage_mb'],
+ 'version': d['version'],
+ 'enforce_ssl': (d['ssl_enforcement'] == 'Enabled'),
+ 'admin_username': d['administrator_login'],
+ 'user_visible_state': d['user_visible_state'],
+ 'fully_qualified_domain_name': d['fully_qualified_domain_name'],
+ 'tags': d.get('tags')
+ }
+
+ return d
+
+
+def main():
+ AzureRMMariaDbServerInfo()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_resource.py b/test/support/integration/plugins/modules/azure_rm_resource.py
new file mode 100644
index 00000000..6ea3e3bb
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_resource.py
@@ -0,0 +1,427 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_resource
+version_added: "2.6"
+short_description: Create any Azure resource
+description:
+ - Create, update or delete any Azure resource using Azure REST API.
+ - This module gives access to resources that are not supported via Ansible modules.
+ - Refer to U(https://docs.microsoft.com/en-us/rest/api/) regarding details related to specific resource REST API.
+
+options:
+ url:
+ description:
+ - Azure RM Resource URL.
+ api_version:
+ description:
+ - Specific API version to be used.
+ provider:
+ description:
+ - Provider type.
+ - Required if URL is not specified.
+ resource_group:
+ description:
+ - Resource group to be used.
+ - Required if URL is not specified.
+ resource_type:
+ description:
+ - Resource type.
+ - Required if URL is not specified.
+ resource_name:
+ description:
+ - Resource name.
+ - Required if URL is not specified.
+ subresource:
+ description:
+ - List of subresources.
+ suboptions:
+ namespace:
+ description:
+ - Subresource namespace.
+ type:
+ description:
+ - Subresource type.
+ name:
+ description:
+ - Subresource name.
+ body:
+ description:
+ - The body of the HTTP request/response to the web service.
+ method:
+ description:
+ - The HTTP method of the request or response. It must be uppercase.
+ choices:
+ - GET
+ - PUT
+ - POST
+ - HEAD
+ - PATCH
+ - DELETE
+ - MERGE
+ default: "PUT"
+ status_code:
+ description:
+ - A valid, numeric, HTTP status code that signifies success of the request. Can also be a comma-separated list of status codes.
+ type: list
+ default: [ 200, 201, 202 ]
+ idempotency:
+ description:
+ - If enabled, idempotency check will be done by using I(method=GET) first and then comparing with I(body).
+ default: no
+ type: bool
+ polling_timeout:
+ description:
+ - Timeout in seconds for polling a long-running operation to completion.
+ default: 0
+ type: int
+ version_added: "2.8"
+ polling_interval:
+ description:
+ - Interval in seconds between polls of a long-running operation.
+ default: 60
+ type: int
+ version_added: "2.8"
+ state:
+ description:
+ - Assert the state of the resource. Use C(present) to create or update resource or C(absent) to delete resource.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+
+'''
+
+EXAMPLES = '''
+ - name: Update scaleset info using azure_rm_resource
+ azure_rm_resource:
+ resource_group: myResourceGroup
+ provider: compute
+ resource_type: virtualmachinescalesets
+ resource_name: myVmss
+ api_version: "2017-12-01"
+ body: { body }
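+
+ # A minimal sketch, assuming the scaleset above: state: absent switches the
+ # method to DELETE and removes the resource.
+ - name: Delete scaleset using azure_rm_resource
+ azure_rm_resource:
+ resource_group: myResourceGroup
+ provider: compute
+ resource_type: virtualmachinescalesets
+ resource_name: myVmss
+ api_version: "2017-12-01"
+ state: absent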
+'''
+
+RETURN = '''
+response:
+ description:
+ - Response specific to resource type.
+ returned: always
+ type: complex
+ contains:
+ id:
+ description:
+ - Resource ID.
+ type: str
+ returned: always
+ sample: "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Storage/storageAccounts/staccb57dc95183"
+ kind:
+ description:
+ - The kind of storage.
+ type: str
+ returned: always
+ sample: Storage
+ location:
+ description:
+ - The resource location, defaults to location of the resource group.
+ type: str
+ returned: always
+ sample: eastus
+ name:
+ description:
+ - The storage account name.
+ type: str
+ returned: always
+ sample: staccb57dc95183
+ properties:
+ description:
+ - The storage account's related properties.
+ type: dict
+ returned: always
+ sample: {
+ "creationTime": "2019-06-13T06:34:33.0996676Z",
+ "encryption": {
+ "keySource": "Microsoft.Storage",
+ "services": {
+ "blob": {
+ "enabled": true,
+ "lastEnabledTime": "2019-06-13T06:34:33.1934074Z"
+ },
+ "file": {
+ "enabled": true,
+ "lastEnabledTime": "2019-06-13T06:34:33.1934074Z"
+ }
+ }
+ },
+ "networkAcls": {
+ "bypass": "AzureServices",
+ "defaultAction": "Allow",
+ "ipRules": [],
+ "virtualNetworkRules": []
+ },
+ "primaryEndpoints": {
+ "blob": "https://staccb57dc95183.blob.core.windows.net/",
+ "file": "https://staccb57dc95183.file.core.windows.net/",
+ "queue": "https://staccb57dc95183.queue.core.windows.net/",
+ "table": "https://staccb57dc95183.table.core.windows.net/"
+ },
+ "primaryLocation": "eastus",
+ "provisioningState": "Succeeded",
+ "secondaryLocation": "westus",
+ "statusOfPrimary": "available",
+ "statusOfSecondary": "available",
+ "supportsHttpsTrafficOnly": false
+ }
+ sku:
+ description:
+ - The storage account SKU.
+ type: dict
+ returned: always
+ sample: {
+ "name": "Standard_GRS",
+ "tier": "Standard"
+ }
+ tags:
+ description:
+ - Resource tags.
+ type: dict
+ returned: always
+ sample: { 'key1': 'value1' }
+ type:
+ description:
+ - The resource type.
+ type: str
+ returned: always
+ sample: "Microsoft.Storage/storageAccounts"
+
+'''
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+from ansible.module_utils.azure_rm_common_rest import GenericRestClient
+from ansible.module_utils.common.dict_transformations import dict_merge
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.service_client import ServiceClient
+ from msrestazure.tools import resource_id, is_valid_resource_id
+ import json
+
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class AzureRMResource(AzureRMModuleBase):
+ def __init__(self):
+ # define user inputs into argument
+ self.module_arg_spec = dict(
+ url=dict(
+ type='str'
+ ),
+ provider=dict(
+ type='str',
+ ),
+ resource_group=dict(
+ type='str',
+ ),
+ resource_type=dict(
+ type='str',
+ ),
+ resource_name=dict(
+ type='str',
+ ),
+ subresource=dict(
+ type='list',
+ default=[]
+ ),
+ api_version=dict(
+ type='str'
+ ),
+ method=dict(
+ type='str',
+ default='PUT',
+ choices=["GET", "PUT", "POST", "HEAD", "PATCH", "DELETE", "MERGE"]
+ ),
+ body=dict(
+ type='raw'
+ ),
+ status_code=dict(
+ type='list',
+ default=[200, 201, 202]
+ ),
+ idempotency=dict(
+ type='bool',
+ default=False
+ ),
+ polling_timeout=dict(
+ type='int',
+ default=0
+ ),
+ polling_interval=dict(
+ type='int',
+ default=60
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+ # store the results of the module operation
+ self.results = dict(
+ changed=False,
+ response=None
+ )
+ self.mgmt_client = None
+ self.url = None
+ self.api_version = None
+ self.provider = None
+ self.resource_group = None
+ self.resource_type = None
+ self.resource_name = None
+ self.subresource_type = None
+ self.subresource_name = None
+ self.subresource = []
+ self.method = None
+ self.status_code = []
+ self.idempotency = False
+ self.polling_timeout = None
+ self.polling_interval = None
+ self.state = None
+ self.body = None
+ super(AzureRMResource, self).__init__(self.module_arg_spec, supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+ self.mgmt_client = self.get_mgmt_svc_client(GenericRestClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+
+ if self.state == 'absent':
+ self.method = 'DELETE'
+ self.status_code.append(204)
+
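+ # No explicit URL was given: build one from subscription, resource group,
+ # provider namespace, resource type/name and any subresources via resource_id().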
+ if self.url is None:
+ orphan = None
+ rargs = dict()
+ rargs['subscription'] = self.subscription_id
+ rargs['resource_group'] = self.resource_group
+ if not (self.provider is None or self.provider.lower().startswith('microsoft.')):
+ rargs['namespace'] = "Microsoft." + self.provider
+ else:
+ rargs['namespace'] = self.provider
+
+ if self.resource_type is not None and self.resource_name is not None:
+ rargs['type'] = self.resource_type
+ rargs['name'] = self.resource_name
+ for i in range(len(self.subresource)):
+ resource_ns = self.subresource[i].get('namespace', None)
+ resource_type = self.subresource[i].get('type', None)
+ resource_name = self.subresource[i].get('name', None)
+ if resource_type is not None and resource_name is not None:
+ rargs['child_namespace_' + str(i + 1)] = resource_ns
+ rargs['child_type_' + str(i + 1)] = resource_type
+ rargs['child_name_' + str(i + 1)] = resource_name
+ else:
+ orphan = resource_type
+ else:
+ orphan = self.resource_type
+
+ self.url = resource_id(**rargs)
+
+ if orphan is not None:
+ self.url += '/' + orphan
+
+ # if api_version was not specified, get latest one
+ if not self.api_version:
+ try:
+ # extract provider and resource type
+ if "/providers/" in self.url:
+ provider = self.url.split("/providers/")[1].split("/")[0]
+ resourceType = self.url.split(provider + "/")[1].split("/")[0]
+ url = "/subscriptions/" + self.subscription_id + "/providers/" + provider
+ api_versions = json.loads(self.mgmt_client.query(url, "GET", {'api-version': '2015-01-01'}, None, None, [200], 0, 0).text)
+ for rt in api_versions['resourceTypes']:
+ if rt['resourceType'].lower() == resourceType.lower():
+ self.api_version = rt['apiVersions'][0]
+ break
+ else:
+ # if there's no provider in API version, assume Microsoft.Resources
+ self.api_version = '2018-05-01'
+ if not self.api_version:
+ self.fail("Couldn't find api version for {0}/{1}".format(provider, resourceType))
+ except Exception as exc:
+ self.fail("Failed to obtain API version: {0}".format(str(exc)))
+
+ query_parameters = {}
+ query_parameters['api-version'] = self.api_version
+
+ header_parameters = {}
+ header_parameters['Content-Type'] = 'application/json; charset=utf-8'
+
+ needs_update = True
+ response = None
+
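+ # Idempotency check: GET the current resource and merge the requested body
+ # into it; the write request is only sent when the merge would change anything.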
+ if self.idempotency:
+ original = self.mgmt_client.query(self.url, "GET", query_parameters, None, None, [200, 404], 0, 0)
+
+ if original.status_code == 404:
+ if self.state == 'absent':
+ needs_update = False
+ else:
+ try:
+ response = json.loads(original.text)
+ needs_update = (dict_merge(response, self.body) != response)
+ except Exception:
+ pass
+
+ if needs_update:
+ response = self.mgmt_client.query(self.url,
+ self.method,
+ query_parameters,
+ header_parameters,
+ self.body,
+ self.status_code,
+ self.polling_timeout,
+ self.polling_interval)
+ if self.state == 'present':
+ try:
+ response = json.loads(response.text)
+ except Exception:
+ response = response.text
+ else:
+ response = None
+
+ self.results['response'] = response
+ self.results['changed'] = needs_update
+
+ return self.results
+
+
+def main():
+ AzureRMResource()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_resource_info.py b/test/support/integration/plugins/modules/azure_rm_resource_info.py
new file mode 100644
index 00000000..f797f662
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_resource_info.py
@@ -0,0 +1,432 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_resource_info
+version_added: "2.9"
+short_description: Generic facts of Azure resources
+description:
+ - Obtain facts of any resource using Azure REST API.
+ - This module gives access to resources that are not supported via Ansible modules.
+ - Refer to U(https://docs.microsoft.com/en-us/rest/api/) regarding details related to specific resource REST API.
+
+options:
+ url:
+ description:
+ - Azure RM Resource URL.
+ api_version:
+ description:
+ - Specific API version to be used.
+ provider:
+ description:
+ - Provider type. Should be specified if no URL is given.
+ resource_group:
+ description:
+ - Resource group to be used.
+ - Required if URL is not specified.
+ resource_type:
+ description:
+ - Resource type.
+ resource_name:
+ description:
+ - Resource name.
+ subresource:
+ description:
+ - List of subresources.
+ suboptions:
+ namespace:
+ description:
+ - Subresource namespace.
+ type:
+ description:
+ - Subresource type.
+ name:
+ description:
+ - Subresource name.
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Zim Kalinowski (@zikalino)
+
+'''
+
+EXAMPLES = '''
+ - name: Get scaleset info
+ azure_rm_resource_info:
+ resource_group: myResourceGroup
+ provider: compute
+ resource_type: virtualmachinescalesets
+ resource_name: myVmss
+ api_version: "2017-12-01"
+
+ - name: Query all the resources in the resource group
+ azure_rm_resource_info:
+ resource_group: "{{ resource_group }}"
+ resource_type: resources
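+
+ # A minimal sketch of the documented subresource option; the network names and
+ # api_version are assumed placeholders.
+ - name: Get a subnet of a virtual network
+ azure_rm_resource_info:
+ resource_group: myResourceGroup
+ provider: network
+ resource_type: virtualnetworks
+ resource_name: myVnet
+ subresource:
+ - type: subnets
+ name: mySubnet
+ api_version: "2018-08-01"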
+'''
+
+RETURN = '''
+response:
+ description:
+ - Response specific to resource type.
+ returned: always
+ type: complex
+ contains:
+ id:
+ description:
+ - Id of the Azure resource.
+ type: str
+ returned: always
+ sample: "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Compute/virtualMachines/myVM"
+ location:
+ description:
+ - Resource location.
+ type: str
+ returned: always
+ sample: eastus
+ name:
+ description:
+ - Resource name.
+ type: str
+ returned: always
+ sample: myVM
+ properties:
+ description:
+ - Specifies the virtual machine's property.
+ type: complex
+ returned: always
+ contains:
+ diagnosticsProfile:
+ description:
+ - Specifies the boot diagnostic settings state.
+ type: complex
+ returned: always
+ contains:
+ bootDiagnostics:
+ description:
+ - A debugging feature that lets you view console output and a screenshot to diagnose VM status.
+ type: dict
+ returned: always
+ sample: {
+ "enabled": true,
+ "storageUri": "https://vxisurgdiag.blob.core.windows.net/"
+ }
+ hardwareProfile:
+ description:
+ - Specifies the hardware settings for the virtual machine.
+ type: dict
+ returned: always
+ sample: {
+ "vmSize": "Standard_D2s_v3"
+ }
+ networkProfile:
+ description:
+ - Specifies the network interfaces of the virtual machine.
+ type: complex
+ returned: always
+ contains:
+ networkInterfaces:
+ description:
+ - Describes a network interface reference.
+ type: list
+ returned: always
+ sample:
+ - {
+ "id": "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Network/networkInterfaces/myvm441"
+ }
+ osProfile:
+ description:
+ - Specifies the operating system settings for the virtual machine.
+ type: complex
+ returned: always
+ contains:
+ adminUsername:
+ description:
+ - Specifies the name of the administrator account.
+ type: str
+ returned: always
+ sample: azureuser
+ allowExtensionOperations:
+ description:
+ - Specifies whether extension operations should be allowed on the virtual machine.
+ - This may only be set to False when no extensions are present on the virtual machine.
+ type: bool
+ returned: always
+ sample: true
+ computerName:
+ description:
+ - Specifies the host OS name of the virtual machine.
+ type: str
+ returned: always
+ sample: myVM
+ requireGuestProvisionSignal:
+ description:
+ - Specifies whether the host requires a guest provision signal.
+ type: bool
+ returned: always
+ sample: true
+ secrets:
+ description:
+ - Specifies set of certificates that should be installed onto the virtual machine.
+ type: list
+ returned: always
+ sample: []
+ linuxConfiguration:
+ description:
+ - Specifies the Linux operating system settings on the virtual machine.
+ type: dict
+ returned: when OS type is Linux
+ sample: {
+ "disablePasswordAuthentication": false,
+ "provisionVMAgent": true
+ }
+ provisioningState:
+ description:
+ - The provisioning state.
+ type: str
+ returned: always
+ sample: Succeeded
+ vmID:
+ description:
+ - Specifies the VM unique ID, a 128-bit identifier that is encoded and stored in the SMBIOS of all Azure IaaS VMs.
+ - It can be read using platform BIOS commands.
+ type: str
+ returned: always
+ sample: "eb86d9bb-6725-4787-a487-2e497d5b340c"
+ storageProfile:
+ description:
+ - Specifies the storage settings for the virtual machine disks.
+ type: complex
+ returned: always
+ contains:
+ dataDisks:
+ description:
+ - Specifies the parameters that are used to add a data disk to virtual machine.
+ type: list
+ returned: always
+ sample:
+ - {
+ "caching": "None",
+ "createOption": "Attach",
+ "diskSizeGB": 1023,
+ "lun": 2,
+ "managedDisk": {
+ "id": "/subscriptions/xxxx....xxxx/resourceGroups/V-XISURG/providers/Microsoft.Compute/disks/testdisk2",
+ "storageAccountType": "StandardSSD_LRS"
+ },
+ "name": "testdisk2"
+ }
+ - {
+ "caching": "None",
+ "createOption": "Attach",
+ "diskSizeGB": 1023,
+ "lun": 1,
+ "managedDisk": {
+ "id": "/subscriptions/xxxx...xxxx/resourceGroups/V-XISURG/providers/Microsoft.Compute/disks/testdisk3",
+ "storageAccountType": "StandardSSD_LRS"
+ },
+ "name": "testdisk3"
+ }
+
+ imageReference:
+ description:
+ - Specifies information about the image to use.
+ type: dict
+ returned: always
+ sample: {
+ "offer": "UbuntuServer",
+ "publisher": "Canonical",
+ "sku": "18.04-LTS",
+ "version": "latest"
+ }
+ osDisk:
+ description:
+ - Specifies information about the operating system disk used by the virtual machine.
+ type: dict
+ returned: always
+ sample: {
+ "caching": "ReadWrite",
+ "createOption": "FromImage",
+ "diskSizeGB": 30,
+ "managedDisk": {
+ "id": "/subscriptions/xxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Compute/disks/myVM_disk1_xxx",
+ "storageAccountType": "Premium_LRS"
+ },
+ "name": "myVM_disk1_xxx",
+ "osType": "Linux"
+ }
+ type:
+ description:
+ - The resource type.
+ type: str
+ returned: always
+ sample: "Microsoft.Compute/virtualMachines"
+'''
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+from ansible.module_utils.azure_rm_common_rest import GenericRestClient
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.service_client import ServiceClient
+ from msrestazure.tools import resource_id, is_valid_resource_id
+ import json
+
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+
+class AzureRMResourceInfo(AzureRMModuleBase):
+ def __init__(self):
+ # define user inputs into argument
+ self.module_arg_spec = dict(
+ url=dict(
+ type='str'
+ ),
+ provider=dict(
+ type='str'
+ ),
+ resource_group=dict(
+ type='str'
+ ),
+ resource_type=dict(
+ type='str'
+ ),
+ resource_name=dict(
+ type='str'
+ ),
+ subresource=dict(
+ type='list',
+ default=[]
+ ),
+ api_version=dict(
+ type='str'
+ )
+ )
+ # store the results of the module operation
+ self.results = dict(
+ response=[]
+ )
+ self.mgmt_client = None
+ self.url = None
+ self.api_version = None
+ self.provider = None
+ self.resource_group = None
+ self.resource_type = None
+ self.resource_name = None
+ self.subresource = []
+ super(AzureRMResourceInfo, self).__init__(self.module_arg_spec, supports_tags=False)
+
+ def exec_module(self, **kwargs):
+ is_old_facts = self.module._name == 'azure_rm_resource_facts'
+ if is_old_facts:
+ self.module.deprecate("The 'azure_rm_resource_facts' module has been renamed to 'azure_rm_resource_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+ self.mgmt_client = self.get_mgmt_svc_client(GenericRestClient,
+ base_url=self._cloud_environment.endpoints.resource_manager)
+
+ if self.url is None:
+ orphan = None
+ rargs = dict()
+ rargs['subscription'] = self.subscription_id
+ rargs['resource_group'] = self.resource_group
+ if not (self.provider is None or self.provider.lower().startswith('microsoft.')):
+ rargs['namespace'] = "Microsoft." + self.provider
+ else:
+ rargs['namespace'] = self.provider
+
+ if self.resource_type is not None and self.resource_name is not None:
+ rargs['type'] = self.resource_type
+ rargs['name'] = self.resource_name
+ for i in range(len(self.subresource)):
+ resource_ns = self.subresource[i].get('namespace', None)
+ resource_type = self.subresource[i].get('type', None)
+ resource_name = self.subresource[i].get('name', None)
+ if resource_type is not None and resource_name is not None:
+ rargs['child_namespace_' + str(i + 1)] = resource_ns
+ rargs['child_type_' + str(i + 1)] = resource_type
+ rargs['child_name_' + str(i + 1)] = resource_name
+ else:
+ orphan = resource_type
+ else:
+ orphan = self.resource_type
+
+ self.url = resource_id(**rargs)
+
+ if orphan is not None:
+ self.url += '/' + orphan
+
+ # if api_version was not specified, get latest one
+ if not self.api_version:
+ try:
+ # extract provider and resource type
+ if "/providers/" in self.url:
+ provider = self.url.split("/providers/")[1].split("/")[0]
+ resourceType = self.url.split(provider + "/")[1].split("/")[0]
+ url = "/subscriptions/" + self.subscription_id + "/providers/" + provider
+ api_versions = json.loads(self.mgmt_client.query(url, "GET", {'api-version': '2015-01-01'}, None, None, [200], 0, 0).text)
+ for rt in api_versions['resourceTypes']:
+ if rt['resourceType'].lower() == resourceType.lower():
+ self.api_version = rt['apiVersions'][0]
+ break
+ else:
+ # if there's no provider in API version, assume Microsoft.Resources
+ self.api_version = '2018-05-01'
+ if not self.api_version:
+ self.fail("Couldn't find api version for {0}/{1}".format(provider, resourceType))
+ except Exception as exc:
+ self.fail("Failed to obtain API version: {0}".format(str(exc)))
+
+ self.results['url'] = self.url
+
+ query_parameters = {}
+ query_parameters['api-version'] = self.api_version
+
+ header_parameters = {}
+ header_parameters['Content-Type'] = 'application/json; charset=utf-8'
+ skiptoken = None
+
+ while True:
+ if skiptoken:
+ query_parameters['skiptoken'] = skiptoken
+ response = self.mgmt_client.query(self.url, "GET", query_parameters, header_parameters, None, [200, 404], 0, 0)
+ try:
+ response = json.loads(response.text)
+ if isinstance(response, dict):
+ if response.get('value'):
+ self.results['response'] = self.results['response'] + response['value']
+ skiptoken = response.get('nextLink')
+ else:
+ self.results['response'] = self.results['response'] + [response]
+ except Exception as e:
+ self.fail('Failed to parse response: ' + str(e))
+ if not skiptoken:
+ break
+ return self.results
+
+
+def main():
+ AzureRMResourceInfo()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_storageaccount.py b/test/support/integration/plugins/modules/azure_rm_storageaccount.py
new file mode 100644
index 00000000..d4158bbd
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_storageaccount.py
@@ -0,0 +1,684 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2016 Matt Davis, <mdavis@ansible.com>
+# Chris Houseknecht, <house@redhat.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_storageaccount
+version_added: "2.1"
+short_description: Manage Azure storage accounts
+description:
+ - Create, update or delete a storage account.
+options:
+ resource_group:
+ description:
+ - Name of the resource group to use.
+ required: true
+ aliases:
+ - resource_group_name
+ name:
+ description:
+ - Name of the storage account to update or create.
+ state:
+ description:
+ - State of the storage account. Use C(present) to create or update a storage account and use C(absent) to delete an account.
+ default: present
+ choices:
+ - absent
+ - present
+ location:
+ description:
+ - Valid Azure location. Defaults to location of the resource group.
+ account_type:
+ description:
+ - Type of storage account. Required when creating a storage account.
+ - C(Standard_ZRS) and C(Premium_LRS) accounts cannot be changed to other account types.
+ - Other account types cannot be changed to C(Standard_ZRS) or C(Premium_LRS).
+ choices:
+ - Premium_LRS
+ - Standard_GRS
+ - Standard_LRS
+ - StandardSSD_LRS
+ - Standard_RAGRS
+ - Standard_ZRS
+ - Premium_ZRS
+ aliases:
+ - type
+ custom_domain:
+ description:
+ - User domain assigned to the storage account.
+ - Must be a dictionary with I(name) and I(use_sub_domain) keys where I(name) is the CNAME source.
+ - Only one custom domain is supported per storage account at this time.
+ - To clear the existing custom domain, use an empty string for the custom domain name property.
+ - Can be added to an existing storage account. Will be ignored during storage account creation.
+ aliases:
+ - custom_dns_domain_suffix
+ kind:
+ description:
+ - The kind of storage.
+ default: 'Storage'
+ choices:
+ - Storage
+ - StorageV2
+ - BlobStorage
+ version_added: "2.2"
+ access_tier:
+ description:
+ - The access tier for this storage account. Required when I(kind=BlobStorage).
+ choices:
+ - Hot
+ - Cool
+ version_added: "2.4"
+ force_delete_nonempty:
+ description:
+ - Attempt deletion if resource already exists and cannot be updated.
+ type: bool
+ aliases:
+ - force
+ https_only:
+ description:
+ - Allows HTTPS traffic only to the storage service when set to C(true).
+ type: bool
+ version_added: "2.8"
+ blob_cors:
+ description:
+ - Specifies CORS rules for the Blob service.
+ - You can include up to five CorsRule elements in the request.
+ - If no blob_cors elements are included in the argument list, nothing about CORS will be changed.
+ - If you want to delete all CORS rules and disable CORS for the Blob service, explicitly set I(blob_cors=[]).
+ type: list
+ version_added: "2.8"
+ suboptions:
+ allowed_origins:
+ description:
+ - A list of origin domains that will be allowed via CORS, or "*" to allow all domains.
+ type: list
+ required: true
+ allowed_methods:
+ description:
+ - A list of HTTP methods that are allowed to be executed by the origin.
+ type: list
+ required: true
+ max_age_in_seconds:
+ description:
+ - The number of seconds that the client/browser should cache a preflight response.
+ type: int
+ required: true
+ exposed_headers:
+ description:
+ - A list of response headers to expose to CORS clients.
+ type: list
+ required: true
+ allowed_headers:
+ description:
+ - A list of headers allowed to be part of the cross-origin request.
+ type: list
+ required: true
+
+extends_documentation_fragment:
+ - azure
+ - azure_tags
+
+author:
+ - Chris Houseknecht (@chouseknecht)
+ - Matt Davis (@nitzmahone)
+'''
+
+EXAMPLES = '''
+ - name: remove account, if it exists
+ azure_rm_storageaccount:
+ resource_group: myResourceGroup
+ name: clh0002
+ state: absent
+
+ - name: create an account
+ azure_rm_storageaccount:
+ resource_group: myResourceGroup
+ name: clh0002
+ type: Standard_RAGRS
+ tags:
+ testing: testing
+ delete: on-exit
+
+ - name: create an account with blob CORS
+ azure_rm_storageaccount:
+ resource_group: myResourceGroup
+ name: clh002
+ type: Standard_RAGRS
+ blob_cors:
+ - allowed_origins:
+ - http://www.example.com/
+ allowed_methods:
+ - GET
+ - POST
+ allowed_headers:
+ - x-ms-meta-data*
+ - x-ms-meta-target*
+ - x-ms-meta-abc
+ exposed_headers:
+ - x-ms-meta-*
+ max_age_in_seconds: 200
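+
+  # A hedged sketch (not among the original examples): clearing a configured
+  # custom domain by passing an empty name, per the custom_domain option
+  # documentation above; the account name is illustrative.
+  - name: clear the custom domain on an existing account
+    azure_rm_storageaccount:
+      resource_group: myResourceGroup
+      name: clh0002
+      custom_domain:
+        name: ''
+        use_sub_domain: false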
+'''
+
+
+RETURN = '''
+state:
+ description:
+ - Current state of the storage account.
+ returned: always
+ type: complex
+ contains:
+ account_type:
+ description:
+ - Type of storage account.
+ returned: always
+ type: str
+ sample: Standard_RAGRS
+ custom_domain:
+ description:
+ - User domain assigned to the storage account.
+ returned: always
+ type: complex
+ contains:
+ name:
+ description:
+ - CNAME source.
+ returned: always
+ type: str
+ sample: testaccount
+ use_sub_domain:
+ description:
+ - Whether to use sub domain.
+ returned: always
+ type: bool
+ sample: true
+ id:
+ description:
+ - Resource ID.
+ returned: always
+ type: str
+ sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Storage/storageAccounts/clh0003"
+ location:
+ description:
+ - Valid Azure location. Defaults to location of the resource group.
+ returned: always
+ type: str
+ sample: eastus2
+ name:
+ description:
+ - Name of the storage account to update or create.
+ returned: always
+ type: str
+ sample: clh0003
+ primary_endpoints:
+ description:
+ - The URLs to retrieve the public I(blob), I(queue), or I(table) object from the primary location.
+ returned: always
+ type: dict
+ sample: {
+ "blob": "https://clh0003.blob.core.windows.net/",
+ "queue": "https://clh0003.queue.core.windows.net/",
+ "table": "https://clh0003.table.core.windows.net/"
+ }
+ primary_location:
+ description:
+ - The location of the primary data center for the storage account.
+ returned: always
+ type: str
+ sample: eastus2
+ provisioning_state:
+ description:
+ - The status of the storage account.
+ - Possible values include C(Creating), C(ResolvingDNS), C(Succeeded).
+ returned: always
+ type: str
+ sample: Succeeded
+ resource_group:
+ description:
+ - The resource group's name.
+ returned: always
+ type: str
+ sample: Testing
+ secondary_endpoints:
+ description:
+ - The URLs to retrieve the public I(blob), I(queue), or I(table) object from the secondary location.
+ returned: always
+ type: dict
+ sample: {
+ "blob": "https://clh0003-secondary.blob.core.windows.net/",
+ "queue": "https://clh0003-secondary.queue.core.windows.net/",
+ "table": "https://clh0003-secondary.table.core.windows.net/"
+ }
+ secondary_location:
+ description:
+ - The location of the geo-replicated secondary for the storage account.
+ returned: always
+ type: str
+ sample: centralus
+ status_of_primary:
+ description:
+ - The status of the primary location of the storage account; either C(available) or C(unavailable).
+ returned: always
+ type: str
+ sample: available
+ status_of_secondary:
+ description:
+ - The status of the secondary location of the storage account; either C(available) or C(unavailable).
+ returned: always
+ type: str
+ sample: available
+ tags:
+ description:
+ - Resource tags.
+ returned: always
+ type: dict
+ sample: { 'tags1': 'value1' }
+ type:
+ description:
+ - The storage account type.
+ returned: always
+ type: str
+ sample: "Microsoft.Storage/storageAccounts"
+'''
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from azure.storage.cloudstorageaccount import CloudStorageAccount
+ from azure.common import AzureMissingResourceHttpError
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+import copy
+from ansible.module_utils.azure_rm_common import AZURE_SUCCESS_STATE, AzureRMModuleBase
+from ansible.module_utils._text import to_native
+
+cors_rule_spec = dict(
+ allowed_origins=dict(type='list', elements='str', required=True),
+ allowed_methods=dict(type='list', elements='str', required=True),
+ max_age_in_seconds=dict(type='int', required=True),
+ exposed_headers=dict(type='list', elements='str', required=True),
+ allowed_headers=dict(type='list', elements='str', required=True),
+)
+
+
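+# Order-insensitive comparison of two CORS rule lists; e.g. a rule with
+# allowed_methods ['GET', 'POST'] matches one listing ['POST', 'GET'],
+# since each field is compared as a set (illustrative values).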
+def compare_cors(cors1, cors2):
+    if len(cors1) != len(cors2):
+        return False
+    copy2 = copy.copy(cors2)
+    for rule1 in cors1:
+        matched = False
+        for rule2 in copy2:
+            if (rule1['max_age_in_seconds'] == rule2['max_age_in_seconds']
+                    and set(rule1['allowed_methods']) == set(rule2['allowed_methods'])
+                    and set(rule1['allowed_origins']) == set(rule2['allowed_origins'])
+                    and set(rule1['allowed_headers']) == set(rule2['allowed_headers'])
+                    and set(rule1['exposed_headers']) == set(rule2['exposed_headers'])):
+                matched = True
+                copy2.remove(rule2)
+                # stop scanning: removing while iterating would otherwise
+                # skip the element after the removed rule
+                break
+        if not matched:
+            return False
+    return True
+
+
+class AzureRMStorageAccount(AzureRMModuleBase):
+
+ def __init__(self):
+
+ self.module_arg_spec = dict(
+ account_type=dict(type='str',
+ choices=['Premium_LRS', 'Standard_GRS', 'Standard_LRS', 'StandardSSD_LRS', 'Standard_RAGRS', 'Standard_ZRS', 'Premium_ZRS'],
+ aliases=['type']),
+ custom_domain=dict(type='dict', aliases=['custom_dns_domain_suffix']),
+ location=dict(type='str'),
+ name=dict(type='str', required=True),
+ resource_group=dict(required=True, type='str', aliases=['resource_group_name']),
+ state=dict(default='present', choices=['present', 'absent']),
+ force_delete_nonempty=dict(type='bool', default=False, aliases=['force']),
+ tags=dict(type='dict'),
+ kind=dict(type='str', default='Storage', choices=['Storage', 'StorageV2', 'BlobStorage']),
+ access_tier=dict(type='str', choices=['Hot', 'Cool']),
+ https_only=dict(type='bool', default=False),
+ blob_cors=dict(type='list', options=cors_rule_spec, elements='dict')
+ )
+
+ self.results = dict(
+ changed=False,
+ state=dict()
+ )
+
+ self.account_dict = None
+ self.resource_group = None
+ self.name = None
+ self.state = None
+ self.location = None
+ self.account_type = None
+ self.custom_domain = None
+ self.tags = None
+ self.force_delete_nonempty = None
+ self.kind = None
+ self.access_tier = None
+ self.https_only = None
+ self.blob_cors = None
+
+ super(AzureRMStorageAccount, self).__init__(self.module_arg_spec,
+ supports_check_mode=True)
+
+ def exec_module(self, **kwargs):
+
+ for key in list(self.module_arg_spec.keys()) + ['tags']:
+ setattr(self, key, kwargs[key])
+
+ resource_group = self.get_resource_group(self.resource_group)
+ if not self.location:
+ # Set default location
+ self.location = resource_group.location
+
+ if len(self.name) < 3 or len(self.name) > 24:
+ self.fail("Parameter error: name length must be between 3 and 24 characters.")
+
+ if self.custom_domain:
+ if self.custom_domain.get('name', None) is None:
+ self.fail("Parameter error: expecting custom_domain to have a name attribute of type string.")
+ if self.custom_domain.get('use_sub_domain', None) is None:
+ self.fail("Parameter error: expecting custom_domain to have a use_sub_domain "
+ "attribute of type boolean.")
+
+ self.account_dict = self.get_account()
+
+ if self.state == 'present' and self.account_dict and \
+ self.account_dict['provisioning_state'] != AZURE_SUCCESS_STATE:
+ self.fail("Error: storage account {0} has not completed provisioning. State is {1}. Expecting state "
+ "to be {2}.".format(self.name, self.account_dict['provisioning_state'], AZURE_SUCCESS_STATE))
+
+ if self.account_dict is not None:
+ self.results['state'] = self.account_dict
+ else:
+ self.results['state'] = dict()
+
+ if self.state == 'present':
+ if not self.account_dict:
+ self.results['state'] = self.create_account()
+ else:
+ self.update_account()
+ elif self.state == 'absent' and self.account_dict:
+ self.delete_account()
+ self.results['state'] = dict(Status='Deleted')
+
+ return self.results
+
+ def check_name_availability(self):
+ self.log('Checking name availability for {0}'.format(self.name))
+ try:
+ response = self.storage_client.storage_accounts.check_name_availability(self.name)
+ except CloudError as e:
+ self.log('Error attempting to validate name.')
+ self.fail("Error checking name availability: {0}".format(str(e)))
+ if not response.name_available:
+ self.log('Error name not available.')
+ self.fail("{0} - {1}".format(response.message, response.reason))
+
+ def get_account(self):
+ self.log('Get properties for account {0}'.format(self.name))
+ account_obj = None
+ blob_service_props = None
+ account_dict = None
+
+ try:
+ account_obj = self.storage_client.storage_accounts.get_properties(self.resource_group, self.name)
+ blob_service_props = self.storage_client.blob_services.get_service_properties(self.resource_group, self.name)
+ except CloudError:
+ pass
+
+ if account_obj:
+ account_dict = self.account_obj_to_dict(account_obj, blob_service_props)
+
+ return account_dict
+
+ def account_obj_to_dict(self, account_obj, blob_service_props=None):
+ account_dict = dict(
+ id=account_obj.id,
+ name=account_obj.name,
+ location=account_obj.location,
+ resource_group=self.resource_group,
+ type=account_obj.type,
+ access_tier=(account_obj.access_tier.value
+ if account_obj.access_tier is not None else None),
+ sku_tier=account_obj.sku.tier.value,
+ sku_name=account_obj.sku.name.value,
+ provisioning_state=account_obj.provisioning_state.value,
+ secondary_location=account_obj.secondary_location,
+ status_of_primary=(account_obj.status_of_primary.value
+ if account_obj.status_of_primary is not None else None),
+ status_of_secondary=(account_obj.status_of_secondary.value
+ if account_obj.status_of_secondary is not None else None),
+ primary_location=account_obj.primary_location,
+ https_only=account_obj.enable_https_traffic_only
+ )
+ account_dict['custom_domain'] = None
+ if account_obj.custom_domain:
+ account_dict['custom_domain'] = dict(
+ name=account_obj.custom_domain.name,
+ use_sub_domain=account_obj.custom_domain.use_sub_domain
+ )
+
+ account_dict['primary_endpoints'] = None
+ if account_obj.primary_endpoints:
+ account_dict['primary_endpoints'] = dict(
+ blob=account_obj.primary_endpoints.blob,
+ queue=account_obj.primary_endpoints.queue,
+ table=account_obj.primary_endpoints.table
+ )
+ account_dict['secondary_endpoints'] = None
+ if account_obj.secondary_endpoints:
+ account_dict['secondary_endpoints'] = dict(
+ blob=account_obj.secondary_endpoints.blob,
+ queue=account_obj.secondary_endpoints.queue,
+ table=account_obj.secondary_endpoints.table
+ )
+ account_dict['tags'] = None
+ if account_obj.tags:
+ account_dict['tags'] = account_obj.tags
+ if blob_service_props and blob_service_props.cors and blob_service_props.cors.cors_rules:
+ account_dict['blob_cors'] = [dict(
+ allowed_origins=[to_native(y) for y in x.allowed_origins],
+ allowed_methods=[to_native(y) for y in x.allowed_methods],
+ max_age_in_seconds=x.max_age_in_seconds,
+ exposed_headers=[to_native(y) for y in x.exposed_headers],
+ allowed_headers=[to_native(y) for y in x.allowed_headers]
+ ) for x in blob_service_props.cors.cors_rules]
+ return account_dict
+
+ def update_account(self):
+ self.log('Update storage account {0}'.format(self.name))
+ if bool(self.https_only) != bool(self.account_dict.get('https_only')):
+ self.results['changed'] = True
+ self.account_dict['https_only'] = self.https_only
+ if not self.check_mode:
+ try:
+ parameters = self.storage_models.StorageAccountUpdateParameters(enable_https_traffic_only=self.https_only)
+ self.storage_client.storage_accounts.update(self.resource_group,
+ self.name,
+ parameters)
+ except Exception as exc:
+ self.fail("Failed to update account type: {0}".format(str(exc)))
+
+        if self.account_type:
+            # resolve the SkuName enum once for the checks and update below
+            SkuName = self.storage_models.SkuName
+            if self.account_type != self.account_dict['sku_name']:
+                # change the account type
+ if self.account_dict['sku_name'] in [SkuName.premium_lrs, SkuName.standard_zrs]:
+ self.fail("Storage accounts of type {0} and {1} cannot be changed.".format(
+ SkuName.premium_lrs, SkuName.standard_zrs))
+ if self.account_type in [SkuName.premium_lrs, SkuName.standard_zrs]:
+ self.fail("Storage account of type {0} cannot be changed to a type of {1} or {2}.".format(
+ self.account_dict['sku_name'], SkuName.premium_lrs, SkuName.standard_zrs))
+
+ self.results['changed'] = True
+ self.account_dict['sku_name'] = self.account_type
+
+ if self.results['changed'] and not self.check_mode:
+ # Perform the update. The API only allows changing one attribute per call.
+ try:
+ self.log("sku_name: %s" % self.account_dict['sku_name'])
+ self.log("sku_tier: %s" % self.account_dict['sku_tier'])
+ sku = self.storage_models.Sku(name=SkuName(self.account_dict['sku_name']))
+ sku.tier = self.storage_models.SkuTier(self.account_dict['sku_tier'])
+ parameters = self.storage_models.StorageAccountUpdateParameters(sku=sku)
+ self.storage_client.storage_accounts.update(self.resource_group,
+ self.name,
+ parameters)
+ except Exception as exc:
+ self.fail("Failed to update account type: {0}".format(str(exc)))
+
+ if self.custom_domain:
+ if not self.account_dict['custom_domain'] or self.account_dict['custom_domain'] != self.custom_domain:
+ self.results['changed'] = True
+ self.account_dict['custom_domain'] = self.custom_domain
+
+ if self.results['changed'] and not self.check_mode:
+ new_domain = self.storage_models.CustomDomain(name=self.custom_domain['name'],
+ use_sub_domain=self.custom_domain['use_sub_domain'])
+ parameters = self.storage_models.StorageAccountUpdateParameters(custom_domain=new_domain)
+ try:
+ self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters)
+ except Exception as exc:
+ self.fail("Failed to update custom domain: {0}".format(str(exc)))
+
+ if self.access_tier:
+ if not self.account_dict['access_tier'] or self.account_dict['access_tier'] != self.access_tier:
+ self.results['changed'] = True
+ self.account_dict['access_tier'] = self.access_tier
+
+ if self.results['changed'] and not self.check_mode:
+ parameters = self.storage_models.StorageAccountUpdateParameters(access_tier=self.access_tier)
+ try:
+ self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters)
+ except Exception as exc:
+ self.fail("Failed to update access tier: {0}".format(str(exc)))
+
+ update_tags, self.account_dict['tags'] = self.update_tags(self.account_dict['tags'])
+ if update_tags:
+ self.results['changed'] = True
+ if not self.check_mode:
+ parameters = self.storage_models.StorageAccountUpdateParameters(tags=self.account_dict['tags'])
+ try:
+ self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters)
+ except Exception as exc:
+ self.fail("Failed to update tags: {0}".format(str(exc)))
+
+ if self.blob_cors and not compare_cors(self.account_dict.get('blob_cors', []), self.blob_cors):
+ self.results['changed'] = True
+ if not self.check_mode:
+ self.set_blob_cors()
+
+ def create_account(self):
+ self.log("Creating account {0}".format(self.name))
+
+ if not self.location:
+ self.fail('Parameter error: location required when creating a storage account.')
+
+ if not self.account_type:
+ self.fail('Parameter error: account_type required when creating a storage account.')
+
+ if not self.access_tier and self.kind == 'BlobStorage':
+ self.fail('Parameter error: access_tier required when creating a storage account of type BlobStorage.')
+
+ self.check_name_availability()
+ self.results['changed'] = True
+
+ if self.check_mode:
+ account_dict = dict(
+ location=self.location,
+ account_type=self.account_type,
+ name=self.name,
+ resource_group=self.resource_group,
+ enable_https_traffic_only=self.https_only,
+ tags=dict()
+ )
+ if self.tags:
+ account_dict['tags'] = self.tags
+ if self.blob_cors:
+ account_dict['blob_cors'] = self.blob_cors
+ return account_dict
+ sku = self.storage_models.Sku(name=self.storage_models.SkuName(self.account_type))
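+        # e.g. 'Standard_RAGRS' maps to SkuTier.standard and 'Premium_LRS'
+        # to SkuTier.premium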
+ sku.tier = self.storage_models.SkuTier.standard if 'Standard' in self.account_type else \
+ self.storage_models.SkuTier.premium
+        parameters = self.storage_models.StorageAccountCreateParameters(sku=sku,
+                                                                        kind=self.kind,
+                                                                        location=self.location,
+                                                                        tags=self.tags,
+                                                                        enable_https_traffic_only=self.https_only,
+                                                                        access_tier=self.access_tier)
+ self.log(str(parameters))
+ try:
+ poller = self.storage_client.storage_accounts.create(self.resource_group, self.name, parameters)
+ self.get_poller_result(poller)
+ except CloudError as e:
+ self.log('Error creating storage account.')
+ self.fail("Failed to create account: {0}".format(str(e)))
+ if self.blob_cors:
+ self.set_blob_cors()
+ # the poller doesn't actually return anything
+ return self.get_account()
+
+ def delete_account(self):
+ if self.account_dict['provisioning_state'] == self.storage_models.ProvisioningState.succeeded.value and \
+ not self.force_delete_nonempty and self.account_has_blob_containers():
+ self.fail("Account contains blob containers. Is it in use? Use the force_delete_nonempty option to attempt deletion.")
+
+ self.log('Delete storage account {0}'.format(self.name))
+ self.results['changed'] = True
+ if not self.check_mode:
+ try:
+ status = self.storage_client.storage_accounts.delete(self.resource_group, self.name)
+ self.log("delete status: ")
+ self.log(str(status))
+ except CloudError as e:
+ self.fail("Failed to delete the account: {0}".format(str(e)))
+ return True
+
+ def account_has_blob_containers(self):
+ '''
+ If there are blob containers, then there are likely VMs depending on this account and it should
+ not be deleted.
+ '''
+ self.log('Checking for existing blob containers')
+ blob_service = self.get_blob_client(self.resource_group, self.name)
+ try:
+ response = blob_service.list_containers()
+ except AzureMissingResourceHttpError:
+ # No blob storage available?
+ return False
+
+ if len(response.items) > 0:
+ return True
+ return False
+
+ def set_blob_cors(self):
+ try:
+ cors_rules = self.storage_models.CorsRules(cors_rules=[self.storage_models.CorsRule(**x) for x in self.blob_cors])
+ self.storage_client.blob_services.set_service_properties(self.resource_group,
+ self.name,
+ self.storage_models.BlobServiceProperties(cors=cors_rules))
+ except Exception as exc:
+ self.fail("Failed to set CORS rules: {0}".format(str(exc)))
+
+
+def main():
+ AzureRMStorageAccount()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_webapp.py b/test/support/integration/plugins/modules/azure_rm_webapp.py
new file mode 100644
index 00000000..4f185f45
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_webapp.py
@@ -0,0 +1,1070 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_webapp
+version_added: "2.7"
+short_description: Manage Web App instances
+description:
+ - Create, update and delete instance of Web App.
+
+options:
+ resource_group:
+ description:
+ - Name of the resource group to which the resource belongs.
+ required: True
+ name:
+ description:
+            - Unique name of the app to create or update. To create or update a deployment slot, append C(/slots/<slot_name>) to the web app name (see the deployment slot example below).
+ required: True
+
+ location:
+ description:
+ - Resource location. If not set, location from the resource group will be used as default.
+
+ plan:
+ description:
+ - App service plan. Required for creation.
+            - Can be the name of an existing app service plan in the same resource group as the web app.
+ - Can be the resource ID of an existing app service plan. For example
+ /subscriptions/<subs_id>/resourceGroups/<resource_group>/providers/Microsoft.Web/serverFarms/<plan_name>.
+ - Can be a dict containing five parameters, defined below.
+ - C(name), name of app service plan.
+ - C(resource_group), resource group of the app service plan.
+ - C(sku), SKU of app service plan, allowed values listed on U(https://azure.microsoft.com/en-us/pricing/details/app-service/linux/).
+            - C(is_linux), whether or not the app service plan is Linux. Defaults to C(False).
+ - C(number_of_workers), number of workers for app service plan.
+
+ frameworks:
+ description:
+ - Set of run time framework settings. Each setting is a dictionary.
+ - See U(https://docs.microsoft.com/en-us/azure/app-service/app-service-web-overview) for more info.
+ suboptions:
+ name:
+ description:
+ - Name of the framework.
+                    - The supported framework lists for Windows and Linux web apps differ.
+                    - Windows web apps support C(java), C(net_framework), C(php), C(python), and C(node) from June 2018.
+                    - Windows web apps support multiple frameworks at the same time.
+                    - Linux web apps support C(java), C(ruby), C(php), C(dotnetcore), and C(node) from June 2018.
+                    - Linux web apps support only one framework.
+ - Java framework is mutually exclusive with others.
+ choices:
+ - java
+ - net_framework
+ - php
+ - python
+ - ruby
+ - dotnetcore
+ - node
+ version:
+ description:
+ - Version of the framework. For Linux web app supported value, see U(https://aka.ms/linux-stacks) for more info.
+ - C(net_framework) supported value sample, C(v4.0) for .NET 4.6 and C(v3.0) for .NET 3.5.
+ - C(php) supported value sample, C(5.5), C(5.6), C(7.0).
+                    - C(python) supported value sample, C(2.7), C(3.4).
+ - C(node) supported value sample, C(6.6), C(6.9).
+                    - C(dotnetcore) supported value sample, C(1.0), C(1.1).
+ - C(ruby) supported value sample, C(2.3).
+ - C(java) supported value sample, C(1.9) for Windows web app. C(1.8) for Linux web app.
+ settings:
+ description:
+ - List of settings of the framework.
+ suboptions:
+ java_container:
+ description:
+ - Name of Java container.
+ - Supported only when I(frameworks=java). Sample values C(Tomcat), C(Jetty).
+ java_container_version:
+ description:
+ - Version of Java container.
+ - Supported only when I(frameworks=java).
+                            - Sample values for C(Tomcat), C(8.0), C(8.5), C(9.0). For C(Jetty), C(9.1), C(9.3).
+
+ container_settings:
+ description:
+ - Web app container settings.
+ suboptions:
+ name:
+ description:
+ - Name of container, for example C(imagename:tag).
+ registry_server_url:
+ description:
+ - Container registry server URL, for example C(mydockerregistry.io).
+ registry_server_user:
+ description:
+ - The container registry server user name.
+ registry_server_password:
+ description:
+ - The container registry server password.
+
+ scm_type:
+ description:
+ - Repository type of deployment source, for example C(LocalGit), C(GitHub).
+ - List of supported values maintained at U(https://docs.microsoft.com/en-us/rest/api/appservice/webapps/createorupdate#scmtype).
+
+ deployment_source:
+ description:
+ - Deployment source for git.
+ suboptions:
+ url:
+ description:
+                    - Repository URL of the deployment source.
+
+ branch:
+ description:
+ - The branch name of the repository.
+ startup_file:
+ description:
+            - The web app's startup file.
+ - Used only for Linux web apps.
+
+ client_affinity_enabled:
+ description:
+ - Whether or not to send session affinity cookies, which route client requests in the same session to the same instance.
+ type: bool
+ default: True
+
+ https_only:
+ description:
+            - Configures the web app to accept only HTTPS requests.
+ type: bool
+
+ dns_registration:
+ description:
+            - Controls whether the web app hostname is registered with DNS on creation. Set to C(false) to force registration, C(true) to skip it.
+ type: bool
+
+ skip_custom_domain_verification:
+ description:
+ - Whether or not to skip verification of custom (non *.azurewebsites.net) domains associated with web app. Set to C(true) to skip.
+ type: bool
+
+ ttl_in_seconds:
+ description:
+ - Time to live in seconds for web app default domain name.
+
+ app_settings:
+ description:
+            - Configure web app application settings as key-value pairs.
+
+ purge_app_settings:
+ description:
+            - Purge any existing application settings and replace them with I(app_settings).
+ type: bool
+
+ app_state:
+ description:
+ - Start/Stop/Restart the web app.
+ type: str
+ choices:
+ - started
+ - stopped
+ - restarted
+ default: started
+
+ state:
+ description:
+ - State of the Web App.
+ - Use C(present) to create or update a Web App and C(absent) to delete it.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+ - azure_tags
+
+author:
+ - Yunge Zhu (@yungezz)
+
+'''
+
+EXAMPLES = '''
+  - name: Create a windows web app with a non-existent app service plan
+ azure_rm_webapp:
+ resource_group: myResourceGroup
+ name: myWinWebapp
+ plan:
+ resource_group: myAppServicePlan_rg
+ name: myAppServicePlan
+ is_linux: false
+ sku: S1
+
+  - name: Create a docker web app with app settings and a docker image
+ azure_rm_webapp:
+ resource_group: myResourceGroup
+ name: myDockerWebapp
+ plan:
+ resource_group: myAppServicePlan_rg
+ name: myAppServicePlan
+ is_linux: true
+ sku: S1
+ number_of_workers: 2
+ app_settings:
+ testkey: testvalue
+ testkey2: testvalue2
+ container_settings:
+ name: ansible/ansible:ubuntu1404
+
+  - name: Create a docker web app with a private ACR registry
+ azure_rm_webapp:
+ resource_group: myResourceGroup
+ name: myDockerWebapp
+ plan: myAppServicePlan
+ app_settings:
+ testkey: testvalue
+ container_settings:
+ name: ansible/ubuntu1404
+ registry_server_url: myregistry.io
+ registry_server_user: user
+ registry_server_password: pass
+
+ - name: Create a linux web app with Node 6.6 framework
+ azure_rm_webapp:
+ resource_group: myResourceGroup
+ name: myLinuxWebapp
+ plan:
+ resource_group: myAppServicePlan_rg
+ name: myAppServicePlan
+ app_settings:
+ testkey: testvalue
+ frameworks:
+ - name: "node"
+ version: "6.6"
+
+ - name: Create a windows web app with node, php
+ azure_rm_webapp:
+ resource_group: myResourceGroup
+ name: myWinWebapp
+ plan:
+ resource_group: myAppServicePlan_rg
+ name: myAppServicePlan
+ app_settings:
+ testkey: testvalue
+ frameworks:
+ - name: "node"
+          version: "6.6"
+ - name: "php"
+ version: "7.0"
+
+ - name: Create a stage deployment slot for an existing web app
+ azure_rm_webapp:
+ resource_group: myResourceGroup
+ name: myWebapp/slots/stage
+ plan:
+ resource_group: myAppServicePlan_rg
+ name: myAppServicePlan
+ app_settings:
+        testkey: testvalue
+
+ - name: Create a linux web app with java framework
+ azure_rm_webapp:
+ resource_group: myResourceGroup
+ name: myLinuxWebapp
+ plan:
+ resource_group: myAppServicePlan_rg
+ name: myAppServicePlan
+ app_settings:
+ testkey: testvalue
+ frameworks:
+ - name: "java"
+ version: "8"
+ settings:
+ java_container: "Tomcat"
+ java_container_version: "8.5"
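+
+  # A hedged sketch (not among the original examples): configuring a git
+  # deployment source per the deployment_source option docs above; the
+  # repository URL and branch are illustrative.
+  - name: Create a web app with a git deployment source
+    azure_rm_webapp:
+      resource_group: myResourceGroup
+      name: myGitWebapp
+      plan:
+        resource_group: myAppServicePlan_rg
+        name: myAppServicePlan
+      deployment_source:
+        url: https://github.com/example/example-webapp
+        branch: master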
+'''
+
+RETURN = '''
+azure_webapp:
+ description:
+ - ID of current web app.
+ returned: always
+ type: str
+ sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myWebApp"
+'''
+
+import time
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+ from msrest.serialization import Model
+ from azure.mgmt.web.models import (
+ site_config, app_service_plan, Site,
+ AppServicePlan, SkuDescription, NameValuePair
+ )
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+container_settings_spec = dict(
+ name=dict(type='str', required=True),
+ registry_server_url=dict(type='str'),
+ registry_server_user=dict(type='str'),
+ registry_server_password=dict(type='str', no_log=True)
+)
+
+deployment_source_spec = dict(
+ url=dict(type='str'),
+ branch=dict(type='str')
+)
+
+
+framework_settings_spec = dict(
+ java_container=dict(type='str', required=True),
+ java_container_version=dict(type='str', required=True)
+)
+
+
+framework_spec = dict(
+ name=dict(
+ type='str',
+ required=True,
+ choices=['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby']),
+ version=dict(type='str', required=True),
+ settings=dict(type='dict', options=framework_settings_spec)
+)
+
+
+def _normalize_sku(sku):
+ if sku is None:
+ return sku
+
+ sku = sku.upper()
+ if sku == 'FREE':
+ return 'F1'
+ elif sku == 'SHARED':
+ return 'D1'
+ return sku
+
+
+def get_sku_name(tier):
+ tier = tier.upper()
+ if tier == 'F1' or tier == "FREE":
+ return 'FREE'
+ elif tier == 'D1' or tier == "SHARED":
+ return 'SHARED'
+ elif tier in ['B1', 'B2', 'B3', 'BASIC']:
+ return 'BASIC'
+ elif tier in ['S1', 'S2', 'S3']:
+ return 'STANDARD'
+ elif tier in ['P1', 'P2', 'P3']:
+ return 'PREMIUM'
+ elif tier in ['P1V2', 'P2V2', 'P3V2']:
+ return 'PREMIUMV2'
+ else:
+ return None
+
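+# Illustrative round-trip (assumed values): _normalize_sku('shared') -> 'D1'
+# and get_sku_name('D1') -> 'SHARED'; unrecognised tiers return None.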
+
+def appserviceplan_to_dict(plan):
+ return dict(
+ id=plan.id,
+ name=plan.name,
+ kind=plan.kind,
+ location=plan.location,
+ reserved=plan.reserved,
+ is_linux=plan.reserved,
+ provisioning_state=plan.provisioning_state,
+ tags=plan.tags if plan.tags else None
+ )
+
+
+def webapp_to_dict(webapp):
+ return dict(
+ id=webapp.id,
+ name=webapp.name,
+ location=webapp.location,
+ client_cert_enabled=webapp.client_cert_enabled,
+ enabled=webapp.enabled,
+ reserved=webapp.reserved,
+ client_affinity_enabled=webapp.client_affinity_enabled,
+ server_farm_id=webapp.server_farm_id,
+ host_names_disabled=webapp.host_names_disabled,
+ https_only=webapp.https_only if hasattr(webapp, 'https_only') else None,
+ skip_custom_domain_verification=webapp.skip_custom_domain_verification if hasattr(webapp, 'skip_custom_domain_verification') else None,
+ ttl_in_seconds=webapp.ttl_in_seconds if hasattr(webapp, 'ttl_in_seconds') else None,
+ state=webapp.state,
+ tags=webapp.tags if webapp.tags else None
+ )
+
+
+class Actions:
+ CreateOrUpdate, UpdateAppSettings, Delete = range(3)
+
+
+class AzureRMWebApps(AzureRMModuleBase):
+ """Configuration class for an Azure RM Web App resource"""
+
+ def __init__(self):
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str',
+ required=True
+ ),
+ location=dict(
+ type='str'
+ ),
+ plan=dict(
+ type='raw'
+ ),
+ frameworks=dict(
+ type='list',
+ elements='dict',
+ options=framework_spec
+ ),
+ container_settings=dict(
+ type='dict',
+ options=container_settings_spec
+ ),
+ scm_type=dict(
+ type='str',
+ ),
+ deployment_source=dict(
+ type='dict',
+ options=deployment_source_spec
+ ),
+ startup_file=dict(
+ type='str'
+ ),
+ client_affinity_enabled=dict(
+ type='bool',
+ default=True
+ ),
+ dns_registration=dict(
+ type='bool'
+ ),
+ https_only=dict(
+ type='bool'
+ ),
+ skip_custom_domain_verification=dict(
+ type='bool'
+ ),
+ ttl_in_seconds=dict(
+ type='int'
+ ),
+ app_settings=dict(
+ type='dict'
+ ),
+ purge_app_settings=dict(
+ type='bool',
+ default=False
+ ),
+ app_state=dict(
+ type='str',
+ choices=['started', 'stopped', 'restarted'],
+ default='started'
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+
+ mutually_exclusive = [['container_settings', 'frameworks']]
+
+ self.resource_group = None
+ self.name = None
+        self.location = None
+        # initialise startup_file so exec_module's hasattr() check copies the
+        # argument onto the instance; otherwise the option is silently ignored
+        self.startup_file = None
+
+ # update in create_or_update as parameters
+ self.client_affinity_enabled = True
+ self.dns_registration = None
+ self.skip_custom_domain_verification = None
+ self.ttl_in_seconds = None
+ self.https_only = None
+
+ self.tags = None
+
+        # site config, e.g. app settings, ssl
+ self.site_config = dict()
+ self.app_settings = dict()
+ self.app_settings_strDic = None
+
+ # app service plan
+ self.plan = None
+
+ # siteSourceControl
+ self.deployment_source = dict()
+
+        # site: first-level arguments used at creation or update, e.g. windows/linux, client_affinity
+ self.site = None
+
+ # property for internal usage, not used for sdk
+ self.container_settings = None
+
+ self.purge_app_settings = False
+ self.app_state = 'started'
+
+ self.results = dict(
+ changed=False,
+ id=None,
+ )
+ self.state = None
+ self.to_do = []
+
+ self.frameworks = None
+
+ # set site_config value from kwargs
+ self.site_config_updatable_properties = ["net_framework_version",
+ "java_version",
+ "php_version",
+ "python_version",
+ "scm_type"]
+
+ # updatable_properties
+ self.updatable_properties = ["client_affinity_enabled",
+ "force_dns_registration",
+ "https_only",
+ "skip_custom_domain_verification",
+ "ttl_in_seconds"]
+
+ self.supported_linux_frameworks = ['ruby', 'php', 'dotnetcore', 'node', 'java']
+ self.supported_windows_frameworks = ['net_framework', 'php', 'python', 'node', 'java']
+
+ super(AzureRMWebApps, self).__init__(derived_arg_spec=self.module_arg_spec,
+ mutually_exclusive=mutually_exclusive,
+ supports_check_mode=True,
+ supports_tags=True)
+
+ def exec_module(self, **kwargs):
+ """Main module execution method"""
+
+ for key in list(self.module_arg_spec.keys()) + ['tags']:
+ if hasattr(self, key):
+ setattr(self, key, kwargs[key])
+ elif kwargs[key] is not None:
+ if key == "scm_type":
+ self.site_config[key] = kwargs[key]
+
+ old_response = None
+ response = None
+ to_be_updated = False
+
+ # set location
+ resource_group = self.get_resource_group(self.resource_group)
+ if not self.location:
+ self.location = resource_group.location
+
+ # get existing web app
+ old_response = self.get_webapp()
+
+ if old_response:
+ self.results['id'] = old_response['id']
+
+ if self.state == 'present':
+ if not self.plan and not old_response:
+ self.fail("Please specify plan for newly created web app.")
+
+ if not self.plan:
+ self.plan = old_response['server_farm_id']
+
+ self.plan = self.parse_resource_to_dict(self.plan)
+
+ # get app service plan
+ is_linux = False
+ old_plan = self.get_app_service_plan()
+ if old_plan:
+ is_linux = old_plan['reserved']
+ else:
+ is_linux = self.plan['is_linux'] if 'is_linux' in self.plan else False
+
+ if self.frameworks:
+ # java is mutually exclusive with other frameworks
+ if len(self.frameworks) > 1 and any(f['name'] == 'java' for f in self.frameworks):
+ self.fail('Java is mutually exclusive with other frameworks.')
+
+ if is_linux:
+ if len(self.frameworks) != 1:
+                        self.fail('Only one framework can be specified for a Linux web app.')
+
+ if self.frameworks[0]['name'] not in self.supported_linux_frameworks:
+ self.fail('Unsupported framework {0} for Linux web app.'.format(self.frameworks[0]['name']))
+
+ self.site_config['linux_fx_version'] = (self.frameworks[0]['name'] + '|' + self.frameworks[0]['version']).upper()
+
+ if self.frameworks[0]['name'] == 'java':
+ if self.frameworks[0]['version'] != '8':
+ self.fail("Linux web app only supports java 8.")
+ if self.frameworks[0]['settings'] and self.frameworks[0]['settings']['java_container'].lower() != 'tomcat':
+ self.fail("Linux web app only supports tomcat container.")
+
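+                        # resulting values look like 'TOMCAT|8.5-jre8' or 'JAVA|8-jre8'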
+ if self.frameworks[0]['settings'] and self.frameworks[0]['settings']['java_container'].lower() == 'tomcat':
+ self.site_config['linux_fx_version'] = 'TOMCAT|' + self.frameworks[0]['settings']['java_container_version'] + '-jre8'
+ else:
+ self.site_config['linux_fx_version'] = 'JAVA|8-jre8'
+ else:
+ for fx in self.frameworks:
+ if fx.get('name') not in self.supported_windows_frameworks:
+ self.fail('Unsupported framework {0} for Windows web app.'.format(fx.get('name')))
+ else:
+ self.site_config[fx.get('name') + '_version'] = fx.get('version')
+
+ if 'settings' in fx and fx['settings'] is not None:
+ for key, value in fx['settings'].items():
+ self.site_config[key] = value
+
+ if not self.app_settings:
+ self.app_settings = dict()
+
+ if self.container_settings:
+ linux_fx_version = 'DOCKER|'
+
+ if self.container_settings.get('registry_server_url'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url']
+
+ linux_fx_version += self.container_settings['registry_server_url'] + '/'
+
+ linux_fx_version += self.container_settings['name']
+
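+                # the resulting value looks like 'DOCKER|myregistry.io/ansible/ubuntu1404'
+                # with a registry URL, or 'DOCKER|ansible/ansible:ubuntu1404' without one
+                # (image names taken from the examples above)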
+ self.site_config['linux_fx_version'] = linux_fx_version
+
+ if self.container_settings.get('registry_server_user'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings['registry_server_user']
+
+ if self.container_settings.get('registry_server_password'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings['registry_server_password']
+
+ # init site
+ self.site = Site(location=self.location, site_config=self.site_config)
+
+ if self.https_only is not None:
+ self.site.https_only = self.https_only
+
+ if self.client_affinity_enabled:
+ self.site.client_affinity_enabled = self.client_affinity_enabled
+
+ # check if the web app already present in the resource group
+ if not old_response:
+ self.log("Web App instance doesn't exist")
+
+ to_be_updated = True
+ self.to_do.append(Actions.CreateOrUpdate)
+ self.site.tags = self.tags
+
+ # service plan is required for creation
+ if not self.plan:
+ self.fail("Please specify app service plan in plan parameter.")
+
+ if not old_plan:
+ # no existing service plan, create one
+ if (not self.plan.get('name') or not self.plan.get('sku')):
+ self.fail('Please specify name, is_linux, sku in plan')
+
+ if 'location' not in self.plan:
+ plan_resource_group = self.get_resource_group(self.plan['resource_group'])
+ self.plan['location'] = plan_resource_group.location
+
+ old_plan = self.create_app_service_plan()
+
+ self.site.server_farm_id = old_plan['id']
+
+                # if linux, set up the startup file as the app command line
+                if old_plan['is_linux']:
+                    if self.startup_file:
+                        self.site_config['app_command_line'] = self.startup_file
+
+ # set app setting
+ if self.app_settings:
+ app_settings = []
+ for key in self.app_settings.keys():
+ app_settings.append(NameValuePair(name=key, value=self.app_settings[key]))
+
+ self.site_config['app_settings'] = app_settings
+ else:
+ # existing web app, do update
+ self.log("Web App instance already exists")
+
+ self.log('Result: {0}'.format(old_response))
+
+ update_tags, self.site.tags = self.update_tags(old_response.get('tags', None))
+
+ if update_tags:
+ to_be_updated = True
+
+ # check if root level property changed
+ if self.is_updatable_property_changed(old_response):
+ to_be_updated = True
+ self.to_do.append(Actions.CreateOrUpdate)
+
+ # check if site_config changed
+ old_config = self.get_webapp_configuration()
+
+ if self.is_site_config_changed(old_config):
+ to_be_updated = True
+ self.to_do.append(Actions.CreateOrUpdate)
+
+ # check if linux_fx_version changed
+ if old_config.linux_fx_version != self.site_config.get('linux_fx_version', ''):
+ to_be_updated = True
+ self.to_do.append(Actions.CreateOrUpdate)
+
+ self.app_settings_strDic = self.list_app_settings()
+
+ # purge existing app_settings:
+ if self.purge_app_settings:
+ to_be_updated = True
+ self.app_settings_strDic = dict()
+ self.to_do.append(Actions.UpdateAppSettings)
+
+ # check if app settings changed
+ if self.purge_app_settings or self.is_app_settings_changed():
+ to_be_updated = True
+ self.to_do.append(Actions.UpdateAppSettings)
+
+ if self.app_settings:
+ for key in self.app_settings.keys():
+ self.app_settings_strDic[key] = self.app_settings[key]
+
+ elif self.state == 'absent':
+ if old_response:
+ self.log("Delete Web App instance")
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ self.delete_webapp()
+
+ self.log('Web App instance deleted')
+
+ else:
+ self.fail("Web app {0} not exists.".format(self.name))
+
+ if to_be_updated:
+ self.log('Need to Create/Update web app')
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ if Actions.CreateOrUpdate in self.to_do:
+ response = self.create_update_webapp()
+
+ self.results['id'] = response['id']
+
+ if Actions.UpdateAppSettings in self.to_do:
+ update_response = self.update_app_settings()
+ self.results['id'] = update_response.id
+
+ webapp = None
+ if old_response:
+ webapp = old_response
+ if response:
+ webapp = response
+
+ if webapp:
+ if (webapp['state'] != 'Stopped' and self.app_state == 'stopped') or \
+ (webapp['state'] != 'Running' and self.app_state == 'started') or \
+ self.app_state == 'restarted':
+
+ self.results['changed'] = True
+ if self.check_mode:
+ return self.results
+
+ self.set_webapp_state(self.app_state)
+
+ return self.results
+
+    # compare existing web app with input, determine whether it's an update operation
+ def is_updatable_property_changed(self, existing_webapp):
+ for property_name in self.updatable_properties:
+ if hasattr(self, property_name) and getattr(self, property_name) is not None and \
+ getattr(self, property_name) != existing_webapp.get(property_name, None):
+ return True
+
+ return False
+
+    # compare framework version properties (xxx_version) in site_config
+ def is_site_config_changed(self, existing_config):
+ for fx_version in self.site_config_updatable_properties:
+ if self.site_config.get(fx_version):
+ if not getattr(existing_config, fx_version) or \
+ getattr(existing_config, fx_version).upper() != self.site_config.get(fx_version).upper():
+ return True
+
+ return False
+
+ # comparing existing app setting with input, determine whether it's changed
+ def is_app_settings_changed(self):
+ if self.app_settings:
+ if self.app_settings_strDic:
+ for key in self.app_settings.keys():
+ if self.app_settings[key] != self.app_settings_strDic.get(key, None):
+ return True
+ else:
+ return True
+ return False
+
+    # comparing deployment source with input, determine whether it's changed
+ def is_deployment_source_changed(self, existing_webapp):
+ if self.deployment_source:
+ if self.deployment_source.get('url') \
+ and self.deployment_source['url'] != existing_webapp.get('site_source_control')['url']:
+ return True
+
+ if self.deployment_source.get('branch') \
+ and self.deployment_source['branch'] != existing_webapp.get('site_source_control')['branch']:
+ return True
+
+ return False
+
+ def create_update_webapp(self):
+ '''
+ Creates or updates Web App with the specified configuration.
+
+ :return: deserialized Web App instance state dictionary
+ '''
+ self.log(
+ "Creating / Updating the Web App instance {0}".format(self.name))
+
+ try:
+ skip_dns_registration = self.dns_registration
+ force_dns_registration = None if self.dns_registration is None else not self.dns_registration
+
+ response = self.web_client.web_apps.create_or_update(resource_group_name=self.resource_group,
+ name=self.name,
+ site_envelope=self.site,
+ skip_dns_registration=skip_dns_registration,
+ skip_custom_domain_verification=self.skip_custom_domain_verification,
+ force_dns_registration=force_dns_registration,
+ ttl_in_seconds=self.ttl_in_seconds)
+ if isinstance(response, LROPoller):
+ response = self.get_poller_result(response)
+
+ except CloudError as exc:
+ self.log('Error attempting to create the Web App instance.')
+ self.fail(
+ "Error creating the Web App instance: {0}".format(str(exc)))
+ return webapp_to_dict(response)
+
+ def delete_webapp(self):
+ '''
+ Deletes specified Web App instance in the specified subscription and resource group.
+
+ :return: True
+ '''
+ self.log("Deleting the Web App instance {0}".format(self.name))
+ try:
+ response = self.web_client.web_apps.delete(resource_group_name=self.resource_group,
+ name=self.name)
+ except CloudError as e:
+ self.log('Error attempting to delete the Web App instance.')
+ self.fail(
+ "Error deleting the Web App instance: {0}".format(str(e)))
+
+ return True
+
+ def get_webapp(self):
+ '''
+ Gets the properties of the specified Web App.
+
+ :return: deserialized Web App instance state dictionary
+ '''
+ self.log(
+ "Checking if the Web App instance {0} is present".format(self.name))
+
+ response = None
+
+ try:
+ response = self.web_client.web_apps.get(resource_group_name=self.resource_group,
+ name=self.name)
+
+ # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError
+ if response is not None:
+ self.log("Response : {0}".format(response))
+ self.log("Web App instance : {0} found".format(response.name))
+ return webapp_to_dict(response)
+
+ except CloudError as ex:
+ pass
+
+ self.log("Didn't find web app {0} in resource group {1}".format(
+ self.name, self.resource_group))
+
+ return False
+
+ def get_app_service_plan(self):
+ '''
+ Gets app service plan
+ :return: deserialized app service plan dictionary
+ '''
+ self.log("Get App Service Plan {0}".format(self.plan['name']))
+
+ try:
+ response = self.web_client.app_service_plans.get(
+ resource_group_name=self.plan['resource_group'],
+ name=self.plan['name'])
+
+ # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError
+ if response is not None:
+ self.log("Response : {0}".format(response))
+ self.log("App Service Plan : {0} found".format(response.name))
+
+ return appserviceplan_to_dict(response)
+ except CloudError as ex:
+ pass
+
+ self.log("Didn't find app service plan {0} in resource group {1}".format(
+ self.plan['name'], self.plan['resource_group']))
+
+ return False
+
+ def create_app_service_plan(self):
+ '''
+ Creates app service plan
+ :return: deserialized app service plan dictionary
+ '''
+ self.log("Create App Service Plan {0}".format(self.plan['name']))
+
+ try:
+ # normalize sku
+ sku = _normalize_sku(self.plan['sku'])
+
+ sku_def = SkuDescription(tier=get_sku_name(
+ sku), name=sku, capacity=(self.plan.get('number_of_workers', None)))
+ plan_def = AppServicePlan(
+ location=self.plan['location'], app_service_plan_name=self.plan['name'], sku=sku_def, reserved=(self.plan.get('is_linux', None)))
+
+ poller = self.web_client.app_service_plans.create_or_update(
+ self.plan['resource_group'], self.plan['name'], plan_def)
+
+ if isinstance(poller, LROPoller):
+ response = self.get_poller_result(poller)
+
+ self.log("Response : {0}".format(response))
+
+ return appserviceplan_to_dict(response)
+ except CloudError as ex:
+ self.fail("Failed to create app service plan {0} in resource group {1}: {2}".format(
+ self.plan['name'], self.plan['resource_group'], str(ex)))
+
+ def list_app_settings(self):
+ '''
+ List application settings
+ :return: deserialized list response
+ '''
+ self.log("List application setting")
+
+ try:
+
+ response = self.web_client.web_apps.list_application_settings(
+ resource_group_name=self.resource_group, name=self.name)
+ self.log("Response : {0}".format(response))
+
+ return response.properties
+ except CloudError as ex:
+ self.fail("Failed to list application settings for web app {0} in resource group {1}: {2}".format(
+ self.name, self.resource_group, str(ex)))
+
+ def update_app_settings(self):
+ '''
+ Update application settings
+ :return: deserialized updating response
+ '''
+ self.log("Update application setting")
+
+ try:
+ response = self.web_client.web_apps.update_application_settings(
+ resource_group_name=self.resource_group, name=self.name, properties=self.app_settings_strDic)
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ self.fail("Failed to update application settings for web app {0} in resource group {1}: {2}".format(
+ self.name, self.resource_group, str(ex)))
+
+ def create_or_update_source_control(self):
+ '''
+ Update site source control
+ :return: deserialized updating response
+ '''
+ self.log("Update site source control")
+
+ if self.deployment_source is None:
+ return False
+
+ self.deployment_source['is_manual_integration'] = False
+ self.deployment_source['is_mercurial'] = False
+
+        try:
+            # the operation lives on the web_apps operations group
+            response = self.web_client.web_apps.create_or_update_source_control(
+                self.resource_group, self.name, self.deployment_source)
+            self.log("Response : {0}".format(response))
+
+            return response.as_dict()
+        except CloudError as ex:
+            self.fail("Failed to update site source control for web app {0} in resource group {1}: {2}".format(
+                self.name, self.resource_group, str(ex)))
+
+ def get_webapp_configuration(self):
+ '''
+ Get web app configuration
+ :return: deserialized web app configuration response
+ '''
+ self.log("Get web app configuration")
+
+ try:
+
+ response = self.web_client.web_apps.get_configuration(
+ resource_group_name=self.resource_group, name=self.name)
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ self.log("Failed to get configuration for web app {0} in resource group {1}: {2}".format(
+ self.name, self.resource_group, str(ex)))
+
+ return False
+
+ def set_webapp_state(self, appstate):
+ '''
+ Start/stop/restart web app
+ :return: deserialized updating response
+ '''
+ try:
+ if appstate == 'started':
+ response = self.web_client.web_apps.start(resource_group_name=self.resource_group, name=self.name)
+ elif appstate == 'stopped':
+ response = self.web_client.web_apps.stop(resource_group_name=self.resource_group, name=self.name)
+ elif appstate == 'restarted':
+ response = self.web_client.web_apps.restart(resource_group_name=self.resource_group, name=self.name)
+ else:
+ self.fail("Invalid web app state {0}".format(appstate))
+
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ request_id = ex.request_id if ex.request_id else ''
+ self.log("Failed to {0} web app {1} in resource group {2}, request_id {3} - {4}".format(
+ appstate, self.name, self.resource_group, request_id, str(ex)))
+
+
+def main():
+ """Main execution"""
+ AzureRMWebApps()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_webapp_info.py b/test/support/integration/plugins/modules/azure_rm_webapp_info.py
new file mode 100644
index 00000000..22286803
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_webapp_info.py
@@ -0,0 +1,489 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_webapp_info
+
+version_added: "2.9"
+
+short_description: Get Azure web app facts
+
+description:
+    - Get facts for a specific web app, all web apps in a resource group, or all web apps in the current subscription.
+
+options:
+ name:
+ description:
+ - Only show results for a specific web app.
+ resource_group:
+ description:
+ - Limit results by resource group.
+ return_publish_profile:
+ description:
+            - Indicate whether to return the publishing profile of the web app.
+ default: False
+ type: bool
+ tags:
+ description:
+ - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'.
+
+extends_documentation_fragment:
+ - azure
+
+author:
+ - Yunge Zhu (@yungezz)
+'''
+
+EXAMPLES = '''
+ - name: Get facts for web app by name
+ azure_rm_webapp_info:
+ resource_group: myResourceGroup
+ name: winwebapp1
+
+ - name: Get facts for web apps in resource group
+ azure_rm_webapp_info:
+ resource_group: myResourceGroup
+
+ - name: Get facts for web apps with tags
+ azure_rm_webapp_info:
+ tags:
+ - testtag
+ - foo:bar
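+
+  # A hedged sketch (not among the original examples): also fetch the
+  # publishing credentials documented under return_publish_profile.
+  - name: Get facts for a web app including its publish profile
+    azure_rm_webapp_info:
+      resource_group: myResourceGroup
+      name: winwebapp1
+      return_publish_profile: true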
+'''
+
+RETURN = '''
+webapps:
+ description:
+ - List of web apps.
+ returned: always
+ type: complex
+ contains:
+ id:
+ description:
+ - ID of the web app.
+ returned: always
+ type: str
+ sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myWebApp
+ name:
+ description:
+ - Name of the web app.
+ returned: always
+ type: str
+ sample: winwebapp1
+ resource_group:
+ description:
+ - Resource group of the web app.
+ returned: always
+ type: str
+ sample: myResourceGroup
+ location:
+ description:
+ - Location of the web app.
+ returned: always
+ type: str
+ sample: eastus
+ plan:
+ description:
+ - ID of app service plan used by the web app.
+ returned: always
+ type: str
+ sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/myAppServicePlan
+ app_settings:
+ description:
+ - App settings of the application. Only returned when web app has app settings.
+ returned: always
+ type: dict
+ sample: {
+ "testkey": "testvalue",
+ "testkey2": "testvalue2"
+ }
+ frameworks:
+ description:
+ - Frameworks of the application. Only returned when web app has frameworks.
+ returned: always
+ type: list
+ sample: [
+ {
+ "name": "net_framework",
+ "version": "v4.0"
+ },
+ {
+ "name": "java",
+ "settings": {
+ "java_container": "tomcat",
+ "java_container_version": "8.5"
+ },
+ "version": "1.7"
+ },
+ {
+ "name": "php",
+ "version": "5.6"
+ }
+ ]
+ availability_state:
+ description:
+ - Availability of this web app.
+ returned: always
+ type: str
+ sample: Normal
+ default_host_name:
+ description:
+ - Host name of the web app.
+ returned: always
+ type: str
+ sample: vxxisurg397winapp4.azurewebsites.net
+ enabled:
+ description:
+                - Indicates whether the web app is enabled.
+ returned: always
+ type: bool
+ sample: true
+ enabled_host_names:
+ description:
+ - Enabled host names of the web app.
+ returned: always
+ type: list
+ sample: [
+ "vxxisurg397winapp4.azurewebsites.net",
+ "vxxisurg397winapp4.scm.azurewebsites.net"
+ ]
+ host_name_ssl_states:
+ description:
+ - SSL state per host names of the web app.
+ returned: always
+ type: list
+ sample: [
+ {
+ "hostType": "Standard",
+ "name": "vxxisurg397winapp4.azurewebsites.net",
+ "sslState": "Disabled"
+ },
+ {
+ "hostType": "Repository",
+ "name": "vxxisurg397winapp4.scm.azurewebsites.net",
+ "sslState": "Disabled"
+ }
+ ]
+ host_names:
+ description:
+ - Host names of the web app.
+ returned: always
+ type: list
+ sample: [
+ "vxxisurg397winapp4.azurewebsites.net"
+ ]
+ outbound_ip_addresses:
+ description:
+ - Outbound IP address of the web app.
+ returned: always
+ type: str
+ sample: "40.71.11.131,40.85.166.200,168.62.166.67,137.135.126.248,137.135.121.45"
+ ftp_publish_url:
+ description:
+ - Publishing URL of the web app when deployment type is FTP.
+ returned: always
+ type: str
+ sample: ftp://xxxx.ftp.azurewebsites.windows.net
+ state:
+ description:
+ - State of the web app.
+ returned: always
+ type: str
+ sample: running
+ publishing_username:
+ description:
+ - Publishing profile user name.
+ returned: only when I(return_publish_profile=True).
+ type: str
+ sample: "$vxxisuRG397winapp4"
+ publishing_password:
+ description:
+ - Publishing profile password.
+ returned: only when I(return_publish_profile=True).
+ type: str
+ sample: "uvANsPQpGjWJmrFfm4Ssd5rpBSqGhjMk11pMSgW2vCsQtNx9tcgZ0xN26s9A"
+ tags:
+ description:
+ - Tags assigned to the resource. Dictionary of string:string pairs.
+ returned: always
+ type: dict
+ sample: { tag1: abc }
+'''
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+ from azure.common import AzureMissingResourceHttpError, AzureHttpError
+except Exception:
+ # This is handled in azure_rm_common
+ pass
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+AZURE_OBJECT_CLASS = 'WebApp'
+
+
+class AzureRMWebAppInfo(AzureRMModuleBase):
+
+ def __init__(self):
+
+ self.module_arg_spec = dict(
+ name=dict(type='str'),
+ resource_group=dict(type='str'),
+ tags=dict(type='list'),
+ return_publish_profile=dict(type='bool', default=False),
+ )
+
+ self.results = dict(
+ changed=False,
+ webapps=[],
+ )
+
+ self.name = None
+ self.resource_group = None
+ self.tags = None
+ self.return_publish_profile = False
+
+ self.framework_names = ['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby']
+
+ super(AzureRMWebAppInfo, self).__init__(self.module_arg_spec,
+ supports_tags=False,
+ facts_module=True)
+
+ def exec_module(self, **kwargs):
+ is_old_facts = self.module._name == 'azure_rm_webapp_facts'
+ if is_old_facts:
+ self.module.deprecate("The 'azure_rm_webapp_facts' module has been renamed to 'azure_rm_webapp_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ for key in self.module_arg_spec:
+ setattr(self, key, kwargs[key])
+
+ if self.name:
+ self.results['webapps'] = self.list_by_name()
+ elif self.resource_group:
+ self.results['webapps'] = self.list_by_resource_group()
+ else:
+ self.results['webapps'] = self.list_all()
+
+ return self.results
+
+ def list_by_name(self):
+ self.log('Get web app {0}'.format(self.name))
+ item = None
+ result = []
+
+ try:
+ item = self.web_client.web_apps.get(self.resource_group, self.name)
+ except CloudError:
+ pass
+
+ if item and self.has_tags(item.tags, self.tags):
+ curated_result = self.get_curated_webapp(self.resource_group, self.name, item)
+ result = [curated_result]
+
+ return result
+
+ def list_by_resource_group(self):
+ self.log('List web apps in resource groups {0}'.format(self.resource_group))
+ try:
+ response = list(self.web_client.web_apps.list_by_resource_group(self.resource_group))
+ except CloudError as exc:
+ request_id = exc.request_id if exc.request_id else ''
+ self.fail("Error listing web apps in resource groups {0}, request id: {1} - {2}".format(self.resource_group, request_id, str(exc)))
+
+ results = []
+ for item in response:
+ if self.has_tags(item.tags, self.tags):
+ curated_output = self.get_curated_webapp(self.resource_group, item.name, item)
+ results.append(curated_output)
+ return results
+
+ def list_all(self):
+ self.log('List web apps in current subscription')
+ try:
+ response = list(self.web_client.web_apps.list())
+ except CloudError as exc:
+ request_id = exc.request_id if exc.request_id else ''
+ self.fail("Error listing web apps, request id {0} - {1}".format(request_id, str(exc)))
+
+ results = []
+ for item in response:
+ if self.has_tags(item.tags, self.tags):
+ curated_output = self.get_curated_webapp(item.resource_group, item.name, item)
+ results.append(curated_output)
+ return results
+
+ def list_webapp_configuration(self, resource_group, name):
+ self.log('Get web app {0} configuration'.format(name))
+
+ response = []
+
+ try:
+ response = self.web_client.web_apps.get_configuration(resource_group_name=resource_group, name=name)
+ except CloudError as ex:
+ request_id = ex.request_id if ex.request_id else ''
+ self.fail('Error getting web app {0} configuration, request id {1} - {2}'.format(name, request_id, str(ex)))
+
+ return response.as_dict()
+
+ def list_webapp_appsettings(self, resource_group, name):
+ self.log('Get web app {0} app settings'.format(name))
+
+ response = []
+
+ try:
+ response = self.web_client.web_apps.list_application_settings(resource_group_name=resource_group, name=name)
+ except CloudError as ex:
+ request_id = ex.request_id if ex.request_id else ''
+ self.fail('Error getting web app {0} app settings, request id {1} - {2}'.format(name, request_id, str(ex)))
+
+ return response.as_dict()
+
+ def get_publish_credentials(self, resource_group, name):
+ self.log('Get web app {0} publish credentials'.format(name))
+        response = None
+        try:
+            poller = self.web_client.web_apps.list_publishing_credentials(resource_group, name)
+            response = self.get_poller_result(poller) if isinstance(poller, LROPoller) else poller
+        except CloudError as ex:
+            request_id = ex.request_id if ex.request_id else ''
+            self.fail('Error getting web app {0} publishing credentials, request id {1} - {2}'.format(name, request_id, str(ex)))
+        return response
+
+ def get_webapp_ftp_publish_url(self, resource_group, name):
+ import xmltodict
+
+ self.log('Get web app {0} app publish profile'.format(name))
+
+ url = None
+ try:
+ content = self.web_client.web_apps.list_publishing_profile_xml_with_secrets(resource_group_name=resource_group, name=name)
+ if not content:
+ return url
+
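+            # The publishing profile is streamed back as XML shaped like:
+            #   <publishData><publishProfile publishMethod="FTP" publishUrl="..." .../></publishData>
+            # xmltodict maps XML attributes to '@'-prefixed keys, which is why
+            # the FTP URL is read from profile['@publishUrl'] below.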
+ full_xml = ''
+ for f in content:
+ full_xml += f.decode()
+ profiles = xmltodict.parse(full_xml, xml_attribs=True)['publishData']['publishProfile']
+
+ if not profiles:
+ return url
+
+ for profile in profiles:
+ if profile['@publishMethod'] == 'FTP':
+ url = profile['@publishUrl']
+
+        except CloudError as ex:
+            self.fail('Error getting web app {0} publishing profile - {1}'.format(name, str(ex)))
+
+ return url
+
+    def get_curated_webapp(self, resource_group, name, webapp):
+        pip = self.serialize_obj(webapp, AZURE_OBJECT_CLASS)
+
+        site_config = None
+        app_settings = None
+        publish_cred = None
+        ftp_publish_url = None
+        try:
+            site_config = self.list_webapp_configuration(resource_group, name)
+            app_settings = self.list_webapp_appsettings(resource_group, name)
+            publish_cred = self.get_publish_credentials(resource_group, name)
+            ftp_publish_url = self.get_webapp_ftp_publish_url(resource_group, name)
+        except CloudError:
+            pass
+ return self.construct_curated_webapp(webapp=pip,
+ configuration=site_config,
+ app_settings=app_settings,
+ deployment_slot=None,
+ ftp_publish_url=ftp_publish_url,
+ publish_credentials=publish_cred)
+
+ def construct_curated_webapp(self,
+ webapp,
+ configuration=None,
+ app_settings=None,
+ deployment_slot=None,
+ ftp_publish_url=None,
+ publish_credentials=None):
+ curated_output = dict()
+ curated_output['id'] = webapp['id']
+ curated_output['name'] = webapp['name']
+ curated_output['resource_group'] = webapp['properties']['resourceGroup']
+ curated_output['location'] = webapp['location']
+ curated_output['plan'] = webapp['properties']['serverFarmId']
+ curated_output['tags'] = webapp.get('tags', None)
+
+        # important properties from the response; these do not map to input arguments
+ curated_output['app_state'] = webapp['properties']['state']
+ curated_output['availability_state'] = webapp['properties']['availabilityState']
+ curated_output['default_host_name'] = webapp['properties']['defaultHostName']
+ curated_output['host_names'] = webapp['properties']['hostNames']
+ curated_output['enabled'] = webapp['properties']['enabled']
+ curated_output['enabled_host_names'] = webapp['properties']['enabledHostNames']
+ curated_output['host_name_ssl_states'] = webapp['properties']['hostNameSslStates']
+ curated_output['outbound_ip_addresses'] = webapp['properties']['outboundIpAddresses']
+
+ # curated site_config
+ if configuration:
+ curated_output['frameworks'] = []
+ for fx_name in self.framework_names:
+ fx_version = configuration.get(fx_name + '_version', None)
+ if fx_version:
+ fx = {
+ 'name': fx_name,
+ 'version': fx_version
+ }
+ # java container setting
+ if fx_name == 'java':
+ if configuration['java_container'] and configuration['java_container_version']:
+ settings = {
+ 'java_container': configuration['java_container'].lower(),
+ 'java_container_version': configuration['java_container_version']
+ }
+ fx['settings'] = settings
+
+ curated_output['frameworks'].append(fx)
+
+ # linux_fx_version
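+            # stored as '<framework>|<version>', e.g. 'NODE|10.1' or 'PHP|7.0'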
+ if configuration.get('linux_fx_version', None):
+ tmp = configuration.get('linux_fx_version').split("|")
+ if len(tmp) == 2:
+ curated_output['frameworks'].append({'name': tmp[0].lower(), 'version': tmp[1]})
+
+ # curated app_settings
+ if app_settings and app_settings.get('properties', None):
+ curated_output['app_settings'] = dict()
+ for item in app_settings['properties']:
+ curated_output['app_settings'][item] = app_settings['properties'][item]
+
+        # curated deployment_slot
+ if deployment_slot:
+ curated_output['deployment_slot'] = deployment_slot
+
+ # ftp_publish_url
+ if ftp_publish_url:
+ curated_output['ftp_publish_url'] = ftp_publish_url
+
+ # curated publish credentials
+ if publish_credentials and self.return_publish_profile:
+ curated_output['publishing_username'] = publish_credentials.publishing_user_name
+ curated_output['publishing_password'] = publish_credentials.publishing_password
+ return curated_output
+
+
+def main():
+ AzureRMWebAppInfo()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/azure_rm_webappslot.py b/test/support/integration/plugins/modules/azure_rm_webappslot.py
new file mode 100644
index 00000000..ddba710b
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_webappslot.py
@@ -0,0 +1,1058 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_webappslot
+version_added: "2.8"
+short_description: Manage Azure Web App slot
+description:
+ - Create, update and delete Azure Web App slot.
+
+options:
+ resource_group:
+ description:
+ - Name of the resource group to which the resource belongs.
+ required: True
+ name:
+ description:
+ - Unique name of the deployment slot to create or update.
+ required: True
+ webapp_name:
+ description:
+            - Name of the web app to which this deployment slot belongs.
+ required: True
+ location:
+ description:
+ - Resource location. If not set, location from the resource group will be used as default.
+ configuration_source:
+ description:
+            - Source slot to clone configuration from when creating the slot. Use the web app's name to refer to the production slot.
+ auto_swap_slot_name:
+ description:
+ - Used to configure target slot name to auto swap, or disable auto swap.
+            - Set it to the target slot name to enable auto swap.
+            - Set it to C(False) to disable auto slot swap.
+ swap:
+ description:
+ - Swap deployment slots of a web app.
+ suboptions:
+ action:
+ description:
+ - Swap types.
+ - C(preview) is to apply target slot settings on source slot first.
+ - C(swap) is to complete swapping.
+ - C(reset) is to reset the swap.
+ choices:
+ - preview
+ - swap
+ - reset
+ default: preview
+ target_slot:
+ description:
+ - Name of target slot to swap. If set to None, then swap with production slot.
+ preserve_vnet:
+ description:
+                    - C(True) to preserve the virtual network during swap, otherwise C(False).
+ type: bool
+ default: True
+ frameworks:
+ description:
+ - Set of run time framework settings. Each setting is a dictionary.
+ - See U(https://docs.microsoft.com/en-us/azure/app-service/app-service-web-overview) for more info.
+ suboptions:
+ name:
+ description:
+ - Name of the framework.
+ - Supported framework list for Windows web app and Linux web app is different.
+ - Windows web apps support C(java), C(net_framework), C(php), C(python), and C(node) from June 2018.
+                    - Windows web apps support multiple frameworks at the same time.
+ - Linux web apps support C(java), C(ruby), C(php), C(dotnetcore), and C(node) from June 2018.
+ - Linux web apps support only one framework.
+ - Java framework is mutually exclusive with others.
+ choices:
+ - java
+ - net_framework
+ - php
+ - python
+ - ruby
+ - dotnetcore
+ - node
+ version:
+ description:
+ - Version of the framework. For Linux web app supported value, see U(https://aka.ms/linux-stacks) for more info.
+ - C(net_framework) supported value sample, C(v4.0) for .NET 4.6 and C(v3.0) for .NET 3.5.
+ - C(php) supported value sample, C(5.5), C(5.6), C(7.0).
+ - C(python) supported value sample, C(5.5), C(5.6), C(7.0).
+ - C(node) supported value sample, C(6.6), C(6.9).
+ - C(dotnetcore) supported value sample, C(1.0), C(1.1), C(1.2).
+                    - C(ruby) supported value sample, C(2.3).
+ - C(java) supported value sample, C(1.9) for Windows web app. C(1.8) for Linux web app.
+ settings:
+ description:
+ - List of settings of the framework.
+ suboptions:
+ java_container:
+ description:
+                            - Name of the Java container. This is supported by the C(java) framework only, for example C(Tomcat), C(Jetty).
+ java_container_version:
+ description:
+                            - Version of the Java container. This is supported by the C(java) framework only.
+ - For C(Tomcat), for example C(8.0), C(8.5), C(9.0). For C(Jetty), for example C(9.1), C(9.3).
+ container_settings:
+ description:
+ - Web app slot container settings.
+ suboptions:
+ name:
+ description:
+ - Name of container, for example C(imagename:tag).
+ registry_server_url:
+ description:
+ - Container registry server URL, for example C(mydockerregistry.io).
+ registry_server_user:
+ description:
+ - The container registry server user name.
+ registry_server_password:
+ description:
+ - The container registry server password.
+ startup_file:
+ description:
+ - The slot startup file.
+ - This only applies for Linux web app slot.
+ app_settings:
+ description:
+ - Configure web app slot application settings. Suboptions are in key value pair format.
+ purge_app_settings:
+ description:
+            - Purge any existing application settings. Replace the slot's application settings with I(app_settings).
+ type: bool
+ deployment_source:
+ description:
+ - Deployment source for git.
+ suboptions:
+ url:
+ description:
+ - Repository URL of deployment source.
+ branch:
+ description:
+ - The branch name of the repository.
+ app_state:
+ description:
+ - Start/Stop/Restart the slot.
+ type: str
+ choices:
+ - started
+ - stopped
+ - restarted
+ default: started
+ state:
+ description:
+ - State of the Web App deployment slot.
+ - Use C(present) to create or update a slot and C(absent) to delete it.
+ default: present
+ choices:
+ - absent
+ - present
+
+extends_documentation_fragment:
+ - azure
+ - azure_tags
+
+author:
+ - Yunge Zhu(@yungezz)
+
+'''
+
+EXAMPLES = '''
+ - name: Create a webapp slot
+ azure_rm_webappslot:
+ resource_group: myResourceGroup
+ webapp_name: myJavaWebApp
+ name: stage
+ configuration_source: myJavaWebApp
+ app_settings:
+ testkey: testvalue
+
+  - name: Swap the slot with the production slot
+ azure_rm_webappslot:
+ resource_group: myResourceGroup
+ webapp_name: myJavaWebApp
+ name: stage
+ swap:
+ action: swap
+
+  - name: Stop the slot
+ azure_rm_webappslot:
+ resource_group: myResourceGroup
+ webapp_name: myJavaWebApp
+ name: stage
+ app_state: stopped
+
+  - name: Update app settings of a webapp slot
+ azure_rm_webappslot:
+ resource_group: myResourceGroup
+ webapp_name: myJavaWebApp
+ name: stage
+ app_settings:
+ testkey: testvalue2
+
+  - name: Update frameworks of a webapp slot
+ azure_rm_webappslot:
+ resource_group: myResourceGroup
+ webapp_name: myJavaWebApp
+ name: stage
+ frameworks:
+ - name: "node"
+ version: "10.1"
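+
+  # Illustrative sketch only - slot, app and target slot names are placeholders.
+  - name: Enable auto swap to the production slot
+    azure_rm_webappslot:
+      resource_group: myResourceGroup
+      webapp_name: myJavaWebApp
+      name: stage
+      auto_swap_slot_name: production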
+'''
+
+RETURN = '''
+id:
+ description:
+ - ID of current slot.
+ returned: always
+ type: str
+ sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/testapp/slots/stage1
+'''
+
+import time
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+ from msrestazure.azure_exceptions import CloudError
+ from msrest.polling import LROPoller
+ from msrest.serialization import Model
+ from azure.mgmt.web.models import (
+ site_config, app_service_plan, Site,
+ AppServicePlan, SkuDescription, NameValuePair
+ )
+except ImportError:
+ # This is handled in azure_rm_common
+ pass
+
+swap_spec = dict(
+ action=dict(
+ type='str',
+ choices=[
+ 'preview',
+ 'swap',
+ 'reset'
+ ],
+ default='preview'
+ ),
+ target_slot=dict(
+ type='str'
+ ),
+ preserve_vnet=dict(
+ type='bool',
+ default=True
+ )
+)
+
+container_settings_spec = dict(
+ name=dict(type='str', required=True),
+ registry_server_url=dict(type='str'),
+ registry_server_user=dict(type='str'),
+ registry_server_password=dict(type='str', no_log=True)
+)
+
+deployment_source_spec = dict(
+ url=dict(type='str'),
+ branch=dict(type='str')
+)
+
+
+framework_settings_spec = dict(
+ java_container=dict(type='str', required=True),
+ java_container_version=dict(type='str', required=True)
+)
+
+
+framework_spec = dict(
+ name=dict(
+ type='str',
+ required=True,
+ choices=['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby']),
+ version=dict(type='str', required=True),
+ settings=dict(type='dict', options=framework_settings_spec)
+)
+
+
+def webapp_to_dict(webapp):
+ return dict(
+ id=webapp.id,
+ name=webapp.name,
+ location=webapp.location,
+ client_cert_enabled=webapp.client_cert_enabled,
+ enabled=webapp.enabled,
+ reserved=webapp.reserved,
+ client_affinity_enabled=webapp.client_affinity_enabled,
+ server_farm_id=webapp.server_farm_id,
+ host_names_disabled=webapp.host_names_disabled,
+ https_only=webapp.https_only if hasattr(webapp, 'https_only') else None,
+ skip_custom_domain_verification=webapp.skip_custom_domain_verification if hasattr(webapp, 'skip_custom_domain_verification') else None,
+ ttl_in_seconds=webapp.ttl_in_seconds if hasattr(webapp, 'ttl_in_seconds') else None,
+ state=webapp.state,
+ tags=webapp.tags if webapp.tags else None
+ )
+
+
+def slot_to_dict(slot):
+ return dict(
+ id=slot.id,
+ resource_group=slot.resource_group,
+ server_farm_id=slot.server_farm_id,
+ target_swap_slot=slot.target_swap_slot,
+ enabled_host_names=slot.enabled_host_names,
+ slot_swap_status=slot.slot_swap_status,
+ name=slot.name,
+ location=slot.location,
+ enabled=slot.enabled,
+ reserved=slot.reserved,
+ host_names_disabled=slot.host_names_disabled,
+ state=slot.state,
+ repository_site_name=slot.repository_site_name,
+ default_host_name=slot.default_host_name,
+ kind=slot.kind,
+ site_config=slot.site_config,
+ tags=slot.tags if slot.tags else None
+ )
+
+
+class Actions:
+ NoAction, CreateOrUpdate, UpdateAppSettings, Delete = range(4)
+
+
+class AzureRMWebAppSlots(AzureRMModuleBase):
+ """Configuration class for an Azure RM Web App slot resource"""
+
+ def __init__(self):
+ self.module_arg_spec = dict(
+ resource_group=dict(
+ type='str',
+ required=True
+ ),
+ name=dict(
+ type='str',
+ required=True
+ ),
+ webapp_name=dict(
+ type='str',
+ required=True
+ ),
+ location=dict(
+ type='str'
+ ),
+ configuration_source=dict(
+ type='str'
+ ),
+ auto_swap_slot_name=dict(
+ type='raw'
+ ),
+ swap=dict(
+ type='dict',
+ options=swap_spec
+ ),
+ frameworks=dict(
+ type='list',
+ elements='dict',
+ options=framework_spec
+ ),
+ container_settings=dict(
+ type='dict',
+ options=container_settings_spec
+ ),
+ deployment_source=dict(
+ type='dict',
+ options=deployment_source_spec
+ ),
+ startup_file=dict(
+ type='str'
+ ),
+ app_settings=dict(
+ type='dict'
+ ),
+ purge_app_settings=dict(
+ type='bool',
+ default=False
+ ),
+ app_state=dict(
+ type='str',
+ choices=['started', 'stopped', 'restarted'],
+ default='started'
+ ),
+ state=dict(
+ type='str',
+ default='present',
+ choices=['present', 'absent']
+ )
+ )
+
+ mutually_exclusive = [['container_settings', 'frameworks']]
+
+ self.resource_group = None
+ self.name = None
+ self.webapp_name = None
+ self.location = None
+
+ self.auto_swap_slot_name = None
+ self.swap = None
+ self.tags = None
+ self.startup_file = None
+ self.configuration_source = None
+ self.clone = False
+
+        # site config, e.g. app settings, ssl
+ self.site_config = dict()
+ self.app_settings = dict()
+ self.app_settings_strDic = None
+
+ # siteSourceControl
+ self.deployment_source = dict()
+
+        # site, used at creation or update
+ self.site = None
+
+        # property for internal use, not passed to the SDK
+ self.container_settings = None
+
+ self.purge_app_settings = False
+ self.app_state = 'started'
+
+ self.results = dict(
+ changed=False,
+ id=None,
+ )
+ self.state = None
+ self.to_do = Actions.NoAction
+
+ self.frameworks = None
+
+        # framework keys in site_config that are checked for in-place updates
+ self.site_config_updatable_frameworks = ["net_framework_version",
+ "java_version",
+ "php_version",
+ "python_version",
+ "linux_fx_version"]
+
+ self.supported_linux_frameworks = ['ruby', 'php', 'dotnetcore', 'node', 'java']
+ self.supported_windows_frameworks = ['net_framework', 'php', 'python', 'node', 'java']
+
+ super(AzureRMWebAppSlots, self).__init__(derived_arg_spec=self.module_arg_spec,
+ mutually_exclusive=mutually_exclusive,
+ supports_check_mode=True,
+ supports_tags=True)
+
+ def exec_module(self, **kwargs):
+ """Main module execution method"""
+
+ for key in list(self.module_arg_spec.keys()) + ['tags']:
+ if hasattr(self, key):
+ setattr(self, key, kwargs[key])
+ elif kwargs[key] is not None:
+ if key == "scm_type":
+ self.site_config[key] = kwargs[key]
+
+ old_response = None
+ response = None
+ to_be_updated = False
+
+ # set location
+ resource_group = self.get_resource_group(self.resource_group)
+ if not self.location:
+ self.location = resource_group.location
+
+ # get web app
+ webapp_response = self.get_webapp()
+
+ if not webapp_response:
+ self.fail("Web app {0} does not exist in resource group {1}.".format(self.webapp_name, self.resource_group))
+
+ # get slot
+ old_response = self.get_slot()
+
+ # set is_linux
+        is_linux = bool(webapp_response['reserved'])
+
+ if self.state == 'present':
+ if self.frameworks:
+ # java is mutually exclusive with other frameworks
+ if len(self.frameworks) > 1 and any(f['name'] == 'java' for f in self.frameworks):
+ self.fail('Java is mutually exclusive with other frameworks.')
+
+ if is_linux:
+ if len(self.frameworks) != 1:
+                    self.fail('Only one framework can be specified for a Linux web app.')
+
+ if self.frameworks[0]['name'] not in self.supported_linux_frameworks:
+ self.fail('Unsupported framework {0} for Linux web app.'.format(self.frameworks[0]['name']))
+
+ self.site_config['linux_fx_version'] = (self.frameworks[0]['name'] + '|' + self.frameworks[0]['version']).upper()
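+                    # Linux slots encode the stack as '<NAME>|<version>',
+                    # e.g. 'NODE|10.1'; java is special-cased below.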
+
+ if self.frameworks[0]['name'] == 'java':
+ if self.frameworks[0]['version'] != '8':
+ self.fail("Linux web app only supports java 8.")
+
+ if self.frameworks[0].get('settings', {}) and self.frameworks[0]['settings'].get('java_container', None) and \
+ self.frameworks[0]['settings']['java_container'].lower() != 'tomcat':
+ self.fail("Linux web app only supports tomcat container.")
+
+ if self.frameworks[0].get('settings', {}) and self.frameworks[0]['settings'].get('java_container', None) and \
+ self.frameworks[0]['settings']['java_container'].lower() == 'tomcat':
+ self.site_config['linux_fx_version'] = 'TOMCAT|' + self.frameworks[0]['settings']['java_container_version'] + '-jre8'
+ else:
+ self.site_config['linux_fx_version'] = 'JAVA|8-jre8'
+ else:
+ for fx in self.frameworks:
+ if fx.get('name') not in self.supported_windows_frameworks:
+ self.fail('Unsupported framework {0} for Windows web app.'.format(fx.get('name')))
+ else:
+ self.site_config[fx.get('name') + '_version'] = fx.get('version')
+
+ if 'settings' in fx and fx['settings'] is not None:
+ for key, value in fx['settings'].items():
+ self.site_config[key] = value
+
+ if not self.app_settings:
+ self.app_settings = dict()
+
+ if self.container_settings:
+ linux_fx_version = 'DOCKER|'
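+                # container-based slots use 'DOCKER|[<registry>/]<image:tag>'
+                # as linux_fx_version; the value is assembled below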
+
+ if self.container_settings.get('registry_server_url'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url']
+
+ linux_fx_version += self.container_settings['registry_server_url'] + '/'
+
+ linux_fx_version += self.container_settings['name']
+
+ self.site_config['linux_fx_version'] = linux_fx_version
+
+ if self.container_settings.get('registry_server_user'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings['registry_server_user']
+
+ if self.container_settings.get('registry_server_password'):
+ self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings['registry_server_password']
+
+ # set auto_swap_slot_name
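+            # a string value enables auto swap to that slot; an explicit
+            # False clears any configured auto swap target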
+ if self.auto_swap_slot_name and isinstance(self.auto_swap_slot_name, str):
+ self.site_config['auto_swap_slot_name'] = self.auto_swap_slot_name
+ if self.auto_swap_slot_name is False:
+ self.site_config['auto_swap_slot_name'] = None
+
+ # init site
+ self.site = Site(location=self.location, site_config=self.site_config)
+
+        # check if the slot is already present in the web app
+ if not old_response:
+ self.log("Web App slot doesn't exist")
+
+ to_be_updated = True
+ self.to_do = Actions.CreateOrUpdate
+ self.site.tags = self.tags
+
+                # if Linux, set up startup_file
+ if self.startup_file:
+ self.site_config['app_command_line'] = self.startup_file
+
+ # set app setting
+ if self.app_settings:
+ app_settings = []
+ for key in self.app_settings.keys():
+ app_settings.append(NameValuePair(name=key, value=self.app_settings[key]))
+
+ self.site_config['app_settings'] = app_settings
+
+ # clone slot
+ if self.configuration_source:
+ self.clone = True
+
+ else:
+ # existing slot, do update
+ self.log("Web App slot already exists")
+
+ self.log('Result: {0}'.format(old_response))
+
+ update_tags, self.site.tags = self.update_tags(old_response.get('tags', None))
+
+ if update_tags:
+ to_be_updated = True
+
+ # check if site_config changed
+ old_config = self.get_configuration_slot(self.name)
+
+ if self.is_site_config_changed(old_config):
+ to_be_updated = True
+ self.to_do = Actions.CreateOrUpdate
+
+ self.app_settings_strDic = self.list_app_settings_slot(self.name)
+
+ # purge existing app_settings:
+ if self.purge_app_settings:
+ to_be_updated = True
+ self.to_do = Actions.UpdateAppSettings
+ self.app_settings_strDic = dict()
+
+ # check if app settings changed
+ if self.purge_app_settings or self.is_app_settings_changed():
+ to_be_updated = True
+ self.to_do = Actions.UpdateAppSettings
+
+ if self.app_settings:
+ for key in self.app_settings.keys():
+ self.app_settings_strDic[key] = self.app_settings[key]
+
+ elif self.state == 'absent':
+ if old_response:
+ self.log("Delete Web App slot")
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ self.delete_slot()
+
+ self.log('Web App slot deleted')
+
+ else:
+ self.log("Web app slot {0} not exists.".format(self.name))
+
+ if to_be_updated:
+ self.log('Need to Create/Update web app')
+ self.results['changed'] = True
+
+ if self.check_mode:
+ return self.results
+
+ if self.to_do == Actions.CreateOrUpdate:
+ response = self.create_update_slot()
+
+ self.results['id'] = response['id']
+
+ if self.clone:
+ self.clone_slot()
+
+ if self.to_do == Actions.UpdateAppSettings:
+ self.update_app_settings_slot()
+
+ slot = None
+ if response:
+ slot = response
+ if old_response:
+ slot = old_response
+
+ if slot:
+ if (slot['state'] != 'Stopped' and self.app_state == 'stopped') or \
+ (slot['state'] != 'Running' and self.app_state == 'started') or \
+ self.app_state == 'restarted':
+
+ self.results['changed'] = True
+ if self.check_mode:
+ return self.results
+
+ self.set_state_slot(self.app_state)
+
+ if self.swap:
+ self.results['changed'] = True
+ if self.check_mode:
+ return self.results
+
+ self.swap_slot()
+
+ return self.results
+
+ # compare site config
+ def is_site_config_changed(self, existing_config):
+ for fx_version in self.site_config_updatable_frameworks:
+ if self.site_config.get(fx_version):
+ if not getattr(existing_config, fx_version) or \
+ getattr(existing_config, fx_version).upper() != self.site_config.get(fx_version).upper():
+ return True
+
+ if self.auto_swap_slot_name is False and existing_config.auto_swap_slot_name is not None:
+ return True
+ elif self.auto_swap_slot_name and self.auto_swap_slot_name != getattr(existing_config, 'auto_swap_slot_name', None):
+ return True
+ return False
+
+    # compare existing app settings with input to determine whether they changed
+ def is_app_settings_changed(self):
+ if self.app_settings:
+ if len(self.app_settings_strDic) != len(self.app_settings):
+ return True
+
+ if self.app_settings_strDic != self.app_settings:
+ return True
+ return False
+
+    # compare existing deployment source with input to determine whether it changed
+ def is_deployment_source_changed(self, existing_webapp):
+ if self.deployment_source:
+ if self.deployment_source.get('url') \
+ and self.deployment_source['url'] != existing_webapp.get('site_source_control')['url']:
+ return True
+
+ if self.deployment_source.get('branch') \
+ and self.deployment_source['branch'] != existing_webapp.get('site_source_control')['branch']:
+ return True
+
+ return False
+
+ def create_update_slot(self):
+ '''
+ Creates or updates Web App slot with the specified configuration.
+
+ :return: deserialized Web App instance state dictionary
+ '''
+ self.log(
+ "Creating / Updating the Web App slot {0}".format(self.name))
+
+ try:
+ response = self.web_client.web_apps.create_or_update_slot(resource_group_name=self.resource_group,
+ slot=self.name,
+ name=self.webapp_name,
+ site_envelope=self.site)
+ if isinstance(response, LROPoller):
+ response = self.get_poller_result(response)
+
+ except CloudError as exc:
+ self.log('Error attempting to create the Web App slot instance.')
+ self.fail("Error creating the Web App slot: {0}".format(str(exc)))
+ return slot_to_dict(response)
+
+ def delete_slot(self):
+ '''
+        Deletes the specified Web App slot in the specified subscription and resource group.
+
+ :return: True
+ '''
+ self.log("Deleting the Web App slot {0}".format(self.name))
+ try:
+ response = self.web_client.web_apps.delete_slot(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ slot=self.name)
+ except CloudError as e:
+ self.log('Error attempting to delete the Web App slot.')
+            self.fail("Error deleting the Web App slot: {0}".format(str(e)))
+
+ return True
+
+ def get_webapp(self):
+ '''
+ Gets the properties of the specified Web App.
+
+ :return: deserialized Web App instance state dictionary
+ '''
+ self.log(
+ "Checking if the Web App instance {0} is present".format(self.webapp_name))
+
+ response = None
+
+ try:
+ response = self.web_client.web_apps.get(resource_group_name=self.resource_group,
+ name=self.webapp_name)
+
+ # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError
+ if response is not None:
+ self.log("Response : {0}".format(response))
+ self.log("Web App instance : {0} found".format(response.name))
+ return webapp_to_dict(response)
+
+ except CloudError as ex:
+ pass
+
+ self.log("Didn't find web app {0} in resource group {1}".format(
+ self.webapp_name, self.resource_group))
+
+ return False
+
+ def get_slot(self):
+ '''
+ Gets the properties of the specified Web App slot.
+
+ :return: deserialized Web App slot state dictionary
+ '''
+ self.log(
+ "Checking if the Web App slot {0} is present".format(self.name))
+
+ response = None
+
+ try:
+ response = self.web_client.web_apps.get_slot(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ slot=self.name)
+
+ # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError
+ if response is not None:
+ self.log("Response : {0}".format(response))
+ self.log("Web App slot: {0} found".format(response.name))
+ return slot_to_dict(response)
+
+ except CloudError as ex:
+ pass
+
+ self.log("Does not find web app slot {0} in resource group {1}".format(self.name, self.resource_group))
+
+ return False
+
+ def list_app_settings(self):
+ '''
+ List webapp application settings
+ :return: deserialized list response
+ '''
+ self.log("List webapp application setting")
+
+ try:
+
+ response = self.web_client.web_apps.list_application_settings(
+ resource_group_name=self.resource_group, name=self.webapp_name)
+ self.log("Response : {0}".format(response))
+
+ return response.properties
+ except CloudError as ex:
+ self.fail("Failed to list application settings for web app {0} in resource group {1}: {2}".format(
+ self.name, self.resource_group, str(ex)))
+
+ def list_app_settings_slot(self, slot_name):
+ '''
+ List application settings
+ :return: deserialized list response
+ '''
+ self.log("List application setting")
+
+ try:
+
+ response = self.web_client.web_apps.list_application_settings_slot(
+ resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name)
+ self.log("Response : {0}".format(response))
+
+ return response.properties
+ except CloudError as ex:
+ self.fail("Failed to list application settings for web app slot {0} in resource group {1}: {2}".format(
+ self.name, self.resource_group, str(ex)))
+
+ def update_app_settings_slot(self, slot_name=None, app_settings=None):
+ '''
+ Update application settings
+ :return: deserialized updating response
+ '''
+ self.log("Update application setting")
+
+ if slot_name is None:
+ slot_name = self.name
+ if app_settings is None:
+ app_settings = self.app_settings_strDic
+ try:
+ response = self.web_client.web_apps.update_application_settings_slot(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ slot=slot_name,
+ kind=None,
+ properties=app_settings)
+ self.log("Response : {0}".format(response))
+
+ return response.as_dict()
+ except CloudError as ex:
+ self.fail("Failed to update application settings for web app slot {0} in resource group {1}: {2}".format(
+ self.name, self.resource_group, str(ex)))
+
+ def create_or_update_source_control_slot(self):
+ '''
+ Update site source control
+ :return: deserialized updating response
+ '''
+ self.log("Update site source control")
+
+ if self.deployment_source is None:
+ return False
+
+ self.deployment_source['is_manual_integration'] = False
+ self.deployment_source['is_mercurial'] = False
+
+ try:
+            response = self.web_client.web_apps.create_or_update_source_control_slot(
+ resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ site_source_control=self.deployment_source,
+ slot=self.name)
+ self.log("Response : {0}".format(response))
+
+ return response.as_dict()
+ except CloudError as ex:
+ self.fail("Failed to update site source control for web app slot {0} in resource group {1}: {2}".format(
+ self.name, self.resource_group, str(ex)))
+
+ def get_configuration(self):
+ '''
+ Get web app configuration
+ :return: deserialized web app configuration response
+ '''
+ self.log("Get web app configuration")
+
+ try:
+
+ response = self.web_client.web_apps.get_configuration(
+ resource_group_name=self.resource_group, name=self.webapp_name)
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ self.fail("Failed to get configuration for web app {0} in resource group {1}: {2}".format(
+ self.webapp_name, self.resource_group, str(ex)))
+
+ def get_configuration_slot(self, slot_name):
+ '''
+ Get slot configuration
+ :return: deserialized slot configuration response
+ '''
+ self.log("Get web app slot configuration")
+
+ try:
+
+ response = self.web_client.web_apps.get_configuration_slot(
+ resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name)
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ self.fail("Failed to get configuration for web app slot {0} in resource group {1}: {2}".format(
+ slot_name, self.resource_group, str(ex)))
+
+ def update_configuration_slot(self, slot_name=None, site_config=None):
+ '''
+ Update slot configuration
+ :return: deserialized slot configuration response
+ '''
+ self.log("Update web app slot configuration")
+
+ if slot_name is None:
+ slot_name = self.name
+ if site_config is None:
+ site_config = self.site_config
+ try:
+
+ response = self.web_client.web_apps.update_configuration_slot(
+ resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name, site_config=site_config)
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ self.fail("Failed to update configuration for web app slot {0} in resource group {1}: {2}".format(
+ slot_name, self.resource_group, str(ex)))
+
+ def set_state_slot(self, appstate):
+ '''
+ Start/stop/restart web app slot
+ :return: deserialized updating response
+ '''
+ try:
+ if appstate == 'started':
+ response = self.web_client.web_apps.start_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name)
+ elif appstate == 'stopped':
+ response = self.web_client.web_apps.stop_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name)
+ elif appstate == 'restarted':
+ response = self.web_client.web_apps.restart_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name)
+ else:
+ self.fail("Invalid web app slot state {0}".format(appstate))
+
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ request_id = ex.request_id if ex.request_id else ''
+ self.fail("Failed to {0} web app slot {1} in resource group {2}, request_id {3} - {4}".format(
+ appstate, self.name, self.resource_group, request_id, str(ex)))
+
+ def swap_slot(self):
+ '''
+ Swap slot
+ :return: deserialized response
+ '''
+ self.log("Swap slot")
+
+ try:
+ if self.swap['action'] == 'swap':
+ if self.swap['target_slot'] is None:
+ response = self.web_client.web_apps.swap_slot_with_production(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ target_slot=self.name,
+ preserve_vnet=self.swap['preserve_vnet'])
+ else:
+ response = self.web_client.web_apps.swap_slot_slot(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ slot=self.name,
+ target_slot=self.swap['target_slot'],
+ preserve_vnet=self.swap['preserve_vnet'])
+ elif self.swap['action'] == 'preview':
+ if self.swap['target_slot'] is None:
+ response = self.web_client.web_apps.apply_slot_config_to_production(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ target_slot=self.name,
+ preserve_vnet=self.swap['preserve_vnet'])
+ else:
+ response = self.web_client.web_apps.apply_slot_configuration_slot(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ slot=self.name,
+ target_slot=self.swap['target_slot'],
+ preserve_vnet=self.swap['preserve_vnet'])
+ elif self.swap['action'] == 'reset':
+ if self.swap['target_slot'] is None:
+ response = self.web_client.web_apps.reset_production_slot_config(resource_group_name=self.resource_group,
+ name=self.webapp_name)
+ else:
+ response = self.web_client.web_apps.reset_slot_configuration_slot(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ slot=self.swap['target_slot'])
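+                # reset both sides: the target (or production) slot above,
+                # and this slot below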
+ response = self.web_client.web_apps.reset_slot_configuration_slot(resource_group_name=self.resource_group,
+ name=self.webapp_name,
+ slot=self.name)
+
+ self.log("Response : {0}".format(response))
+
+ return response
+ except CloudError as ex:
+ self.fail("Failed to swap web app slot {0} in resource group {1}: {2}".format(self.name, self.resource_group, str(ex)))
+
+ def clone_slot(self):
+ if self.configuration_source:
+ src_slot = None if self.configuration_source.lower() == self.webapp_name.lower() else self.configuration_source
+
+ if src_slot is None:
+ site_config_clone_from = self.get_configuration()
+ else:
+ site_config_clone_from = self.get_configuration_slot(slot_name=src_slot)
+
+ self.update_configuration_slot(site_config=site_config_clone_from)
+
+ if src_slot is None:
+ app_setting_clone_from = self.list_app_settings()
+ else:
+ app_setting_clone_from = self.list_app_settings_slot(src_slot)
+
+ if self.app_settings:
+ app_setting_clone_from.update(self.app_settings)
+
+ self.update_app_settings_slot(app_settings=app_setting_clone_from)
+
+
+def main():
+ """Main execution"""
+ AzureRMWebAppSlots()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/cloud_init_data_facts.py b/test/support/integration/plugins/modules/cloud_init_data_facts.py
new file mode 100644
index 00000000..4f871b99
--- /dev/null
+++ b/test/support/integration/plugins/modules/cloud_init_data_facts.py
@@ -0,0 +1,134 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+#
+# (c) 2018, René Moser <mail@renemoser.net>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: cloud_init_data_facts
+short_description: Retrieve facts of cloud-init.
+description:
+ - Gathers facts by reading the status.json and result.json of cloud-init.
+version_added: 2.6
+author: René Moser (@resmo)
+options:
+ filter:
+ description:
+ - Filter facts
+ choices: [ status, result ]
+notes:
+ - See http://cloudinit.readthedocs.io/ for more information about cloud-init.
+'''
+
+EXAMPLES = '''
+- name: Gather all facts of cloud init
+ cloud_init_data_facts:
+ register: result
+
+- debug:
+ var: result
+
+- name: Wait for cloud init to finish
+ cloud_init_data_facts:
+ filter: status
+ register: res
+ until: "res.cloud_init_data_facts.status.v1.stage is defined and not res.cloud_init_data_facts.status.v1.stage"
+ retries: 50
+ delay: 5
+'''
+
+RETURN = '''
+---
+cloud_init_data_facts:
+ description: Facts of result and status.
+ returned: success
+ type: dict
+  sample: '{
+    "status": {
+        "v1": {
+            "datasource": "DataSourceCloudStack",
+            "errors": []
+        }
+    },
+    "result": {
+        "v1": {
+            "datasource": "DataSourceCloudStack",
+            "init": {
+                "errors": [],
+                "finished": 1522066377.0185432,
+                "start": 1522066375.2648022
+            },
+            "init-local": {
+                "errors": [],
+                "finished": 1522066373.70919,
+                "start": 1522066373.4726632
+            },
+            "modules-config": {
+                "errors": [],
+                "finished": 1522066380.9097016,
+                "start": 1522066379.0011985
+            },
+            "modules-final": {
+                "errors": [],
+                "finished": 1522066383.56594,
+                "start": 1522066382.3449218
+            },
+            "stage": null
+        }
+    }
+}'
+'''
+
+import os
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils._text import to_text
+
+
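+# cloud-init writes its status.json and result.json files under this directory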
+CLOUD_INIT_PATH = "/var/lib/cloud/data/"
+
+
+def gather_cloud_init_data_facts(module):
+ res = {
+ 'cloud_init_data_facts': dict()
+ }
+
+    filter_param = module.params.get('filter')
+    for i in ['result', 'status']:
+        if filter_param is None or filter_param == i:
+            res['cloud_init_data_facts'][i] = dict()
+            json_file = CLOUD_INIT_PATH + i + '.json'
+
+            if os.path.exists(json_file):
+                with open(json_file, 'rb') as f:
+                    contents = to_text(f.read(), errors='surrogate_or_strict')
+
+ if contents:
+ res['cloud_init_data_facts'][i] = module.from_json(contents)
+ return res
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ filter=dict(choices=['result', 'status']),
+ ),
+ supports_check_mode=True,
+ )
+
+ facts = gather_cloud_init_data_facts(module)
+ result = dict(changed=False, ansible_facts=facts, **facts)
+ module.exit_json(**result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/cloudformation.py b/test/support/integration/plugins/modules/cloudformation.py
new file mode 100644
index 00000000..cd031465
--- /dev/null
+++ b/test/support/integration/plugins/modules/cloudformation.py
@@ -0,0 +1,837 @@
+#!/usr/bin/python
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+
+DOCUMENTATION = '''
+---
+module: cloudformation
+short_description: Create or delete an AWS CloudFormation stack
+description:
+  - Launches or updates an AWS CloudFormation stack and waits for it to complete.
+notes:
+ - CloudFormation features change often, and this module tries to keep up. That means your botocore version should be fresh.
+ The version listed in the requirements is the oldest version that works with the module as a whole.
+ Some features may require recent versions, and we do not pinpoint a minimum version for each feature.
+ Instead of relying on the minimum version, keep botocore up to date. AWS is always releasing features and fixing bugs.
+version_added: "1.1"
+options:
+ stack_name:
+ description:
+ - Name of the CloudFormation stack.
+ required: true
+ type: str
+ disable_rollback:
+ description:
+      - If a stack fails to form, rollback will remove the stack.
+ default: false
+ type: bool
+ on_create_failure:
+ description:
+ - Action to take upon failure of stack creation. Incompatible with the I(disable_rollback) option.
+ choices:
+ - DO_NOTHING
+ - ROLLBACK
+ - DELETE
+ version_added: "2.8"
+ type: str
+ create_timeout:
+ description:
+ - The amount of time (in minutes) that can pass before the stack status becomes CREATE_FAILED
+ version_added: "2.6"
+ type: int
+ template_parameters:
+ description:
+      - A dictionary of all the template variables for the stack. The value can be a string or a dict.
+ - Dict can be used to set additional template parameter attributes like UsePreviousValue (see example).
+ default: {}
+ type: dict
+ state:
+ description:
+ - If I(state=present), stack will be created.
+ - If I(state=present) and if stack exists and template has changed, it will be updated.
+ - If I(state=absent), stack will be removed.
+ default: present
+ choices: [ present, absent ]
+ type: str
+ template:
+ description:
+ - The local path of the CloudFormation template.
+ - This must be the full path to the file, relative to the working directory. If using roles this may look
+ like C(roles/cloudformation/files/cloudformation-example.json).
+ - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url)
+ must be specified (but only one of them).
+ - If I(state=present), the stack does exist, and neither I(template),
+ I(template_body) nor I(template_url) are specified, the previous template will be reused.
+ type: path
+ notification_arns:
+ description:
+ - A comma separated list of Simple Notification Service (SNS) topic ARNs to publish stack related events.
+ version_added: "2.0"
+ type: str
+ stack_policy:
+ description:
+ - The path of the CloudFormation stack policy. A policy cannot be removed once placed, but it can be modified.
+        For instance, to allow all updates, see U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/protect-stack-resources.html#d0e9051).
+ version_added: "1.9"
+ type: str
+ tags:
+ description:
+ - Dictionary of tags to associate with stack and its resources during stack creation.
+ - Can be updated later, updating tags removes previous entries.
+ version_added: "1.4"
+ type: dict
+ template_url:
+ description:
+ - Location of file containing the template body. The URL must point to a template (max size 307,200 bytes) located in an
+ S3 bucket in the same region as the stack.
+ - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url)
+ must be specified (but only one of them).
+ - If I(state=present), the stack does exist, and neither I(template), I(template_body) nor I(template_url) are specified,
+ the previous template will be reused.
+ version_added: "2.0"
+ type: str
+ create_changeset:
+ description:
+ - "If stack already exists create a changeset instead of directly applying changes. See the AWS Change Sets docs
+ U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-changesets.html)."
+ - "WARNING: if the stack does not exist, it will be created without changeset. If I(state=absent), the stack will be
+ deleted immediately with no changeset."
+ type: bool
+ default: false
+ version_added: "2.4"
+ changeset_name:
+ description:
+ - Name given to the changeset when creating a changeset.
+ - Only used when I(create_changeset=true).
+ - By default a name prefixed with Ansible-STACKNAME is generated based on input parameters.
+ See the AWS Change Sets docs for more information
+ U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-changesets.html)
+ version_added: "2.4"
+ type: str
+ template_format:
+ description:
+ - This parameter is ignored since Ansible 2.3 and will be removed in Ansible 2.14.
+ - Templates are now passed raw to CloudFormation regardless of format.
+ version_added: "2.0"
+ type: str
+ role_arn:
+ description:
+ - The role that AWS CloudFormation assumes to create the stack. See the AWS CloudFormation Service Role
+ docs U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-iam-servicerole.html)
+ version_added: "2.3"
+ type: str
+ termination_protection:
+ description:
+ - Enable or disable termination protection on the stack. Only works with botocore >= 1.7.18.
+ type: bool
+ version_added: "2.5"
+ template_body:
+ description:
+ - Template body. Use this to pass in the actual body of the CloudFormation template.
+ - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url)
+ must be specified (but only one of them).
+ - If I(state=present), the stack does exist, and neither I(template), I(template_body) nor I(template_url)
+ are specified, the previous template will be reused.
+ version_added: "2.5"
+ type: str
+ events_limit:
+ description:
+ - Maximum number of CloudFormation events to fetch from a stack when creating or updating it.
+ default: 200
+ version_added: "2.7"
+ type: int
+ backoff_delay:
+ description:
+ - Number of seconds to wait for the next retry.
+ default: 3
+ version_added: "2.8"
+ type: int
+ required: False
+ backoff_max_delay:
+ description:
+ - Maximum amount of time to wait between retries.
+ default: 30
+ version_added: "2.8"
+ type: int
+ required: False
+ backoff_retries:
+ description:
+ - Number of times to retry operation.
+      - The AWS API throttling mechanism can fail CloudFormation operations, so we retry a couple of times.
+ default: 10
+ version_added: "2.8"
+ type: int
+ required: False
+ capabilities:
+ description:
+ - Specify capabilities that stack template contains.
+ - Valid values are C(CAPABILITY_IAM), C(CAPABILITY_NAMED_IAM) and C(CAPABILITY_AUTO_EXPAND).
+ type: list
+ elements: str
+ version_added: "2.8"
+ default: [ CAPABILITY_IAM, CAPABILITY_NAMED_IAM ]
+
+author: "James S. Martin (@jsmartin)"
+extends_documentation_fragment:
+- aws
+- ec2
+requirements: [ boto3, botocore>=1.5.45 ]
+'''
+
+EXAMPLES = '''
+- name: create a cloudformation stack
+ cloudformation:
+ stack_name: "ansible-cloudformation"
+ state: "present"
+ region: "us-east-1"
+ disable_rollback: true
+ template: "files/cloudformation-example.json"
+ template_parameters:
+ KeyName: "jmartin"
+ DiskType: "ephemeral"
+ InstanceType: "m1.small"
+ ClusterSize: 3
+ tags:
+ Stack: "ansible-cloudformation"
+
+# Basic role example
+- name: create a stack, specify role that cloudformation assumes
+ cloudformation:
+ stack_name: "ansible-cloudformation"
+ state: "present"
+ region: "us-east-1"
+ disable_rollback: true
+ template: "roles/cloudformation/files/cloudformation-example.json"
+ role_arn: 'arn:aws:iam::123456789012:role/cloudformation-iam-role'
+
+- name: delete a stack
+ cloudformation:
+ stack_name: "ansible-cloudformation-old"
+ state: "absent"
+
+# Create a stack, pass in template from a URL, disable rollback if stack creation fails,
+# pass in some parameters to the template, provide tags for resources created
+- name: create a stack, pass in the template via an URL
+ cloudformation:
+ stack_name: "ansible-cloudformation"
+ state: present
+ region: us-east-1
+ disable_rollback: true
+ template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template
+ template_parameters:
+ KeyName: jmartin
+ DiskType: ephemeral
+ InstanceType: m1.small
+ ClusterSize: 3
+ tags:
+ Stack: ansible-cloudformation
+
+# Create a stack, passing in template body using lookup of Jinja2 template, disable rollback if stack creation fails,
+# pass in some parameters to the template, provide tags for resources created
+- name: create a stack, pass in the template body via lookup template
+ cloudformation:
+ stack_name: "ansible-cloudformation"
+ state: present
+ region: us-east-1
+ disable_rollback: true
+ template_body: "{{ lookup('template', 'cloudformation.j2') }}"
+ template_parameters:
+ KeyName: jmartin
+ DiskType: ephemeral
+ InstanceType: m1.small
+ ClusterSize: 3
+ tags:
+ Stack: ansible-cloudformation
+
+# Pass a template parameter which uses CloudFormation's UsePreviousValue attribute
+# When use_previous_value is set to True, the given value will be ignored and
+# CloudFormation will use the value from a previously submitted template.
+# If use_previous_value is set to False (default) the given value is used.
+- cloudformation:
+ stack_name: "ansible-cloudformation"
+ state: "present"
+ region: "us-east-1"
+ template: "files/cloudformation-example.json"
+ template_parameters:
+ DBSnapshotIdentifier:
+ use_previous_value: True
+ value: arn:aws:rds:es-east-1:000000000000:snapshot:rds:my-db-snapshot
+ DBName:
+ use_previous_value: True
+ tags:
+ Stack: "ansible-cloudformation"
+
+# Enable termination protection on a stack.
+# If the stack already exists, this will update its termination protection
+- name: enable termination protection during stack creation
+ cloudformation:
+ stack_name: my_stack
+ state: present
+ template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template
+ termination_protection: yes
+
+# Configure TimeoutInMinutes before the stack status becomes CREATE_FAILED
+# In this case, if disable_rollback is not set or is set to false, the stack will be rolled back.
+- name: enable termination protection during stack creation
+ cloudformation:
+ stack_name: my_stack
+ state: present
+ template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template
+ create_timeout: 5
+
+# Configure rollback behaviour on the unsuccessful creation of a stack allowing
+# CloudFormation to clean up, or do nothing in the event of an unsuccessful
+# deployment
+# In this case, if on_create_failure is set to "DELETE", it will clean up the stack if
+# it fails to create
+- name: create stack which will delete on creation failure
+ cloudformation:
+ stack_name: my_stack
+ state: present
+ template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template
+ on_create_failure: DELETE
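+
+# Create a changeset for an existing stack instead of updating it directly.
+# Illustrative sketch only - the changeset name is a placeholder.
+- name: create a changeset for an existing stack
+  cloudformation:
+    stack_name: "ansible-cloudformation"
+    state: present
+    region: us-east-1
+    create_changeset: true
+    changeset_name: "my-changeset"
+    template: "files/cloudformation-example.json"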
+'''
+
+RETURN = '''
+events:
+ type: list
+ description: Most recent events in CloudFormation's event log. This may be from a previous run in some cases.
+ returned: always
+ sample: ["StackEvent AWS::CloudFormation::Stack stackname UPDATE_COMPLETE", "StackEvent AWS::CloudFormation::Stack stackname UPDATE_COMPLETE_CLEANUP_IN_PROGRESS"]
+log:
+ description: Debugging logs. Useful when modifying or finding an error.
+ returned: always
+ type: list
+ sample: ["updating stack"]
+change_set_id:
+ description: The ID of the stack change set if one was created
+ returned: I(state=present) and I(create_changeset=true)
+ type: str
+ sample: "arn:aws:cloudformation:us-east-1:012345678901:changeSet/Ansible-StackName-f4496805bd1b2be824d1e315c6884247ede41eb0"
+stack_resources:
+ description: AWS stack resources and their status. List of dictionaries, one dict per resource.
+ returned: state == present
+ type: list
+ sample: [
+ {
+ "last_updated_time": "2016-10-11T19:40:14.979000+00:00",
+ "logical_resource_id": "CFTestSg",
+ "physical_resource_id": "cloudformation2-CFTestSg-16UQ4CYQ57O9F",
+ "resource_type": "AWS::EC2::SecurityGroup",
+ "status": "UPDATE_COMPLETE",
+ "status_reason": null
+ }
+ ]
+stack_outputs:
+ type: dict
+ description: A key:value dictionary of all the stack outputs currently defined. If there are no stack outputs, it is an empty dictionary.
+ returned: state == present
+ sample: {"MySg": "AnsibleModuleTestYAML-CFTestSg-C8UVS567B6NS"}
+''' # NOQA
+
+import json
+import time
+import uuid
+import traceback
+from hashlib import sha1
+
+try:
+ import boto3
+ import botocore
+ HAS_BOTO3 = True
+except ImportError:
+ HAS_BOTO3 = False
+
+from ansible.module_utils.ec2 import ansible_dict_to_boto3_tag_list, AWSRetry, boto3_conn, boto_exception, ec2_argument_spec, get_aws_connection_info
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils._text import to_bytes, to_native
+
+
+def get_stack_events(cfn, stack_name, events_limit, token_filter=None):
+ '''This event data was never correct, it worked as a side effect. So the v2.3 format is different.'''
+ ret = {'events': [], 'log': []}
+
+ try:
+ pg = cfn.get_paginator(
+ 'describe_stack_events'
+ ).paginate(
+ StackName=stack_name,
+ PaginationConfig={'MaxItems': events_limit}
+ )
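+        # When a ClientRequestToken was passed to the stack operation, a
+        # JMESPath filter narrows the paginated events to this run only.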
+ if token_filter is not None:
+ events = list(pg.search(
+ "StackEvents[?ClientRequestToken == '{0}']".format(token_filter)
+ ))
+ else:
+ events = list(pg.search("StackEvents[*]"))
+ except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err:
+ error_msg = boto_exception(err)
+ if 'does not exist' in error_msg:
+ # missing stack, don't bail.
+ ret['log'].append('Stack does not exist.')
+ return ret
+ ret['log'].append('Unknown error: ' + str(error_msg))
+ return ret
+
+ for e in events:
+ eventline = 'StackEvent {ResourceType} {LogicalResourceId} {ResourceStatus}'.format(**e)
+ ret['events'].append(eventline)
+
+ if e['ResourceStatus'].endswith('FAILED'):
+ failline = '{ResourceType} {LogicalResourceId} {ResourceStatus}: {ResourceStatusReason}'.format(**e)
+ ret['log'].append(failline)
+
+ return ret
+
+
+def create_stack(module, stack_params, cfn, events_limit):
+ if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params:
+ module.fail_json(msg="Either 'template', 'template_body' or 'template_url' is required when the stack does not exist.")
+
+ # 'DisableRollback', 'TimeoutInMinutes', 'EnableTerminationProtection' and
+ # 'OnFailure' only apply on creation, not update.
+ if module.params.get('on_create_failure') is not None:
+ stack_params['OnFailure'] = module.params['on_create_failure']
+ else:
+ stack_params['DisableRollback'] = module.params['disable_rollback']
+
+ if module.params.get('create_timeout') is not None:
+ stack_params['TimeoutInMinutes'] = module.params['create_timeout']
+ if module.params.get('termination_protection') is not None:
+ if boto_supports_termination_protection(cfn):
+ stack_params['EnableTerminationProtection'] = bool(module.params.get('termination_protection'))
+ else:
+ module.fail_json(msg="termination_protection parameter requires botocore >= 1.7.18")
+
+ try:
+ response = cfn.create_stack(**stack_params)
+ # Use stack ID to follow stack state in case of on_create_failure = DELETE
+ result = stack_operation(cfn, response['StackId'], 'CREATE', events_limit, stack_params.get('ClientRequestToken', None))
+ except Exception as err:
+ error_msg = boto_exception(err)
+ module.fail_json(msg="Failed to create stack {0}: {1}.".format(stack_params.get('StackName'), error_msg), exception=traceback.format_exc())
+ if not result:
+ module.fail_json(msg="empty result")
+ return result
+
+
+def list_changesets(cfn, stack_name):
+ res = cfn.list_change_sets(StackName=stack_name)
+ return [cs['ChangeSetName'] for cs in res['Summaries']]
+
+
+def create_changeset(module, stack_params, cfn, events_limit):
+ if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params:
+ module.fail_json(msg="Either 'template' or 'template_url' is required.")
+ if module.params['changeset_name'] is not None:
+ stack_params['ChangeSetName'] = module.params['changeset_name']
+
+ # changesets don't accept ClientRequestToken parameters
+ stack_params.pop('ClientRequestToken', None)
+
+ try:
+ changeset_name = build_changeset_name(stack_params)
+ stack_params['ChangeSetName'] = changeset_name
+
+ # Determine if this changeset already exists
+ pending_changesets = list_changesets(cfn, stack_params['StackName'])
+ if changeset_name in pending_changesets:
+ warning = 'WARNING: %d pending changeset(s) exist(s) for this stack!' % len(pending_changesets)
+ result = dict(changed=False, output='ChangeSet %s already exists.' % changeset_name, warnings=[warning])
+ else:
+ cs = cfn.create_change_set(**stack_params)
+ # Make sure we don't enter an infinite loop
+ time_end = time.time() + 600
+ while time.time() < time_end:
+ try:
+ newcs = cfn.describe_change_set(ChangeSetName=cs['Id'])
+ except botocore.exceptions.BotoCoreError as err:
+ error_msg = boto_exception(err)
+ module.fail_json(msg=error_msg)
+                if newcs['Status'] in ('CREATE_PENDING', 'CREATE_IN_PROGRESS'):
+ time.sleep(1)
+ elif newcs['Status'] == 'FAILED' and "The submitted information didn't contain changes" in newcs['StatusReason']:
+ cfn.delete_change_set(ChangeSetName=cs['Id'])
+ result = dict(changed=False,
+ output='The created Change Set did not contain any changes to this stack and was deleted.')
+                    # a failed change set does not trigger any stack events, so skip
+                    # any further processing of the result and return it directly
+ return result
+ else:
+ break
+                # Let's not hog the CPU or spam the AWS API
+ time.sleep(1)
+ result = stack_operation(cfn, stack_params['StackName'], 'CREATE_CHANGESET', events_limit)
+ result['change_set_id'] = cs['Id']
+ result['warnings'] = ['Created changeset named %s for stack %s' % (changeset_name, stack_params['StackName']),
+ 'You can execute it using: aws cloudformation execute-change-set --change-set-name %s' % cs['Id'],
+ 'NOTE that dependencies on this stack might fail due to pending changes!']
+ except Exception as err:
+ error_msg = boto_exception(err)
+ if 'No updates are to be performed.' in error_msg:
+ result = dict(changed=False, output='Stack is already up-to-date.')
+ else:
+ module.fail_json(msg="Failed to create change set: {0}".format(error_msg), exception=traceback.format_exc())
+
+ if not result:
+ module.fail_json(msg="empty result")
+ return result
+
+
+def update_stack(module, stack_params, cfn, events_limit):
+ if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params:
+ stack_params['UsePreviousTemplate'] = True
+
+ # if the state is present and the stack already exists, we try to update it.
+ # AWS will tell us if the stack template and parameters are the same and
+ # don't need to be updated.
+ try:
+ cfn.update_stack(**stack_params)
+ result = stack_operation(cfn, stack_params['StackName'], 'UPDATE', events_limit, stack_params.get('ClientRequestToken', None))
+ except Exception as err:
+ error_msg = boto_exception(err)
+ if 'No updates are to be performed.' in error_msg:
+ result = dict(changed=False, output='Stack is already up-to-date.')
+ else:
+ module.fail_json(msg="Failed to update stack {0}: {1}".format(stack_params.get('StackName'), error_msg), exception=traceback.format_exc())
+ if not result:
+ module.fail_json(msg="empty result")
+ return result
+
+
+def update_termination_protection(module, cfn, stack_name, desired_termination_protection_state):
+ '''updates termination protection of a stack'''
+ if not boto_supports_termination_protection(cfn):
+ module.fail_json(msg="termination_protection parameter requires botocore >= 1.7.18")
+ stack = get_stack_facts(cfn, stack_name)
+ if stack:
+ if stack['EnableTerminationProtection'] is not desired_termination_protection_state:
+ try:
+ cfn.update_termination_protection(
+ EnableTerminationProtection=desired_termination_protection_state,
+ StackName=stack_name)
+ except botocore.exceptions.ClientError as e:
+ module.fail_json(msg=boto_exception(e), exception=traceback.format_exc())
+
+
+def boto_supports_termination_protection(cfn):
+ '''termination protection was added in botocore 1.7.18'''
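+    # Feature detection rather than version parsing: botocore generates client
+    # methods from its service model, so the method exists only on new-enough botocore.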
+ return hasattr(cfn, "update_termination_protection")
+
+
+def stack_operation(cfn, stack_name, operation, events_limit, op_token=None):
+ '''gets the status of a stack while it is created/updated/deleted'''
+ existed = []
+ while True:
+ try:
+ stack = get_stack_facts(cfn, stack_name)
+ existed.append('yes')
+ except Exception:
+ # If the stack previously existed, and now can't be found then it's
+ # been deleted successfully.
+ if 'yes' in existed or operation == 'DELETE': # stacks may delete fast, look in a few ways.
+ ret = get_stack_events(cfn, stack_name, events_limit, op_token)
+ ret.update({'changed': True, 'output': 'Stack Deleted'})
+ return ret
+ else:
+ return {'changed': True, 'failed': True, 'output': 'Stack Not Found', 'exception': traceback.format_exc()}
+ ret = get_stack_events(cfn, stack_name, events_limit, op_token)
+ if not stack:
+ if 'yes' in existed or operation == 'DELETE': # stacks may delete fast, look in a few ways.
+ ret = get_stack_events(cfn, stack_name, events_limit, op_token)
+ ret.update({'changed': True, 'output': 'Stack Deleted'})
+ return ret
+ else:
+ ret.update({'changed': False, 'failed': True, 'output': 'Stack not found.'})
+ return ret
+ # it covers ROLLBACK_COMPLETE and UPDATE_ROLLBACK_COMPLETE
+ # Possible states: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-describing-stacks.html#w1ab2c15c17c21c13
+ elif stack['StackStatus'].endswith('ROLLBACK_COMPLETE') and operation != 'CREATE_CHANGESET':
+ ret.update({'changed': True, 'failed': True, 'output': 'Problem with %s. Rollback complete' % operation})
+ return ret
+ elif stack['StackStatus'] == 'DELETE_COMPLETE' and operation == 'CREATE':
+ ret.update({'changed': True, 'failed': True, 'output': 'Stack create failed. Delete complete.'})
+ return ret
+ # note the ordering of ROLLBACK_COMPLETE, DELETE_COMPLETE, and COMPLETE, because otherwise COMPLETE will match all cases.
+ elif stack['StackStatus'].endswith('_COMPLETE'):
+ ret.update({'changed': True, 'output': 'Stack %s complete' % operation})
+ return ret
+ elif stack['StackStatus'].endswith('_ROLLBACK_FAILED'):
+ ret.update({'changed': True, 'failed': True, 'output': 'Stack %s rollback failed' % operation})
+ return ret
+ # note the ordering of ROLLBACK_FAILED and FAILED, because otherwise FAILED will match both cases.
+ elif stack['StackStatus'].endswith('_FAILED'):
+ ret.update({'changed': True, 'failed': True, 'output': 'Stack %s failed' % operation})
+ return ret
+ else:
+ # this can loop forever :/
+ time.sleep(5)
+ return {'failed': True, 'output': 'Failed for unknown reasons.'}
+
+
+def build_changeset_name(stack_params):
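+    # The generated name is deterministic: identical stack parameters hash to the
+    # same changeset name, e.g. (illustrative) 'Ansible-mystack-<sha1 hexdigest>'.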
+ if 'ChangeSetName' in stack_params:
+ return stack_params['ChangeSetName']
+
+ json_params = json.dumps(stack_params, sort_keys=True)
+
+ return 'Ansible-{0}-{1}'.format(
+ stack_params['StackName'],
+ sha1(to_bytes(json_params, errors='surrogate_or_strict')).hexdigest()
+ )
+
+
+def check_mode_changeset(module, stack_params, cfn):
+ """Create a change set, describe it and delete it before returning check mode outputs."""
+ stack_params['ChangeSetName'] = build_changeset_name(stack_params)
+ # changesets don't accept ClientRequestToken parameters
+ stack_params.pop('ClientRequestToken', None)
+
+ try:
+ change_set = cfn.create_change_set(**stack_params)
+ for i in range(60): # total time 5 min
+ description = cfn.describe_change_set(ChangeSetName=change_set['Id'])
+ if description['Status'] in ('CREATE_COMPLETE', 'FAILED'):
+ break
+ time.sleep(5)
+ else:
+ # if the changeset doesn't finish in 5 mins, this `else` will trigger and fail
+ module.fail_json(msg="Failed to create change set %s" % stack_params['ChangeSetName'])
+
+ cfn.delete_change_set(ChangeSetName=change_set['Id'])
+
+ reason = description.get('StatusReason')
+
+ if description['Status'] == 'FAILED' and "didn't contain changes" in description['StatusReason']:
+ return {'changed': False, 'msg': reason, 'meta': description['StatusReason']}
+ return {'changed': True, 'msg': reason, 'meta': description['Changes']}
+
+ except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err:
+ error_msg = boto_exception(err)
+ module.fail_json(msg=error_msg, exception=traceback.format_exc())
+
+
+def get_stack_facts(cfn, stack_name):
+ try:
+ stack_response = cfn.describe_stacks(StackName=stack_name)
+ stack_info = stack_response['Stacks'][0]
+ except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err:
+ error_msg = boto_exception(err)
+ if 'does not exist' in error_msg:
+ # missing stack, don't bail.
+ return None
+
+ # other error, bail.
+ raise err
+
+ if stack_response and stack_response.get('Stacks', None):
+ stacks = stack_response['Stacks']
+ if len(stacks):
+ stack_info = stacks[0]
+
+ return stack_info
+
+
+def main():
+ argument_spec = ec2_argument_spec()
+ argument_spec.update(dict(
+ stack_name=dict(required=True),
+ template_parameters=dict(required=False, type='dict', default={}),
+ state=dict(default='present', choices=['present', 'absent']),
+ template=dict(default=None, required=False, type='path'),
+ notification_arns=dict(default=None, required=False),
+ stack_policy=dict(default=None, required=False),
+ disable_rollback=dict(default=False, type='bool'),
+ on_create_failure=dict(default=None, required=False, choices=['DO_NOTHING', 'ROLLBACK', 'DELETE']),
+ create_timeout=dict(default=None, type='int'),
+ template_url=dict(default=None, required=False),
+ template_body=dict(default=None, required=False),
+ template_format=dict(removed_in_version='2.14'),
+ create_changeset=dict(default=False, type='bool'),
+ changeset_name=dict(default=None, required=False),
+ role_arn=dict(default=None, required=False),
+ tags=dict(default=None, type='dict'),
+ termination_protection=dict(default=None, type='bool'),
+ events_limit=dict(default=200, type='int'),
+ backoff_retries=dict(type='int', default=10, required=False),
+ backoff_delay=dict(type='int', default=3, required=False),
+ backoff_max_delay=dict(type='int', default=30, required=False),
+ capabilities=dict(type='list', default=['CAPABILITY_IAM', 'CAPABILITY_NAMED_IAM'])
+ )
+ )
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ mutually_exclusive=[['template_url', 'template', 'template_body'],
+ ['disable_rollback', 'on_create_failure']],
+ supports_check_mode=True
+ )
+ if not HAS_BOTO3:
+ module.fail_json(msg='boto3 and botocore are required for this module')
+
+ invalid_capabilities = []
+ user_capabilities = module.params.get('capabilities')
+ for user_cap in user_capabilities:
+ if user_cap not in ['CAPABILITY_IAM', 'CAPABILITY_NAMED_IAM', 'CAPABILITY_AUTO_EXPAND']:
+ invalid_capabilities.append(user_cap)
+
+ if invalid_capabilities:
+ module.fail_json(msg="Specified capabilities are invalid : %r,"
+ " please check documentation for valid capabilities" % invalid_capabilities)
+
+ # collect the parameters that are passed to boto3. Keeps us from having so many scalars floating around.
+ stack_params = {
+ 'Capabilities': user_capabilities,
+ 'ClientRequestToken': to_native(uuid.uuid4()),
+ }
+ state = module.params['state']
+ stack_params['StackName'] = module.params['stack_name']
+
+ if module.params['template'] is not None:
+ with open(module.params['template'], 'r') as template_fh:
+ stack_params['TemplateBody'] = template_fh.read()
+ elif module.params['template_body'] is not None:
+ stack_params['TemplateBody'] = module.params['template_body']
+ elif module.params['template_url'] is not None:
+ stack_params['TemplateURL'] = module.params['template_url']
+
+ if module.params.get('notification_arns'):
+ stack_params['NotificationARNs'] = module.params['notification_arns'].split(',')
+ else:
+ stack_params['NotificationARNs'] = []
+
+ # can't check the policy when verifying.
+ if module.params['stack_policy'] is not None and not module.check_mode and not module.params['create_changeset']:
+ with open(module.params['stack_policy'], 'r') as stack_policy_fh:
+ stack_params['StackPolicyBody'] = stack_policy_fh.read()
+
+ template_parameters = module.params['template_parameters']
+
+ stack_params['Parameters'] = []
+ for k, v in template_parameters.items():
+ if isinstance(v, dict):
+ # set parameter based on a dict to allow additional CFN Parameter Attributes
+ param = dict(ParameterKey=k)
+
+ if 'value' in v:
+ param['ParameterValue'] = str(v['value'])
+
+ if 'use_previous_value' in v and bool(v['use_previous_value']):
+ param['UsePreviousValue'] = True
+ param.pop('ParameterValue', None)
+
+ stack_params['Parameters'].append(param)
+ else:
+ # allow default k/v configuration to set a template parameter
+ stack_params['Parameters'].append({'ParameterKey': k, 'ParameterValue': str(v)})
+
+ if isinstance(module.params.get('tags'), dict):
+ stack_params['Tags'] = ansible_dict_to_boto3_tag_list(module.params['tags'])
+
+ if module.params.get('role_arn'):
+ stack_params['RoleARN'] = module.params['role_arn']
+
+ result = {}
+
+ try:
+ region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)
+ cfn = boto3_conn(module, conn_type='client', resource='cloudformation', region=region, endpoint=ec2_url, **aws_connect_kwargs)
+ except botocore.exceptions.NoCredentialsError as e:
+ module.fail_json(msg=boto_exception(e))
+
+ # Wrap the cloudformation client methods that this module uses with
+ # automatic backoff / retry for throttling error codes
+ backoff_wrapper = AWSRetry.jittered_backoff(
+ retries=module.params.get('backoff_retries'),
+ delay=module.params.get('backoff_delay'),
+ max_delay=module.params.get('backoff_max_delay')
+ )
+ cfn.describe_stack_events = backoff_wrapper(cfn.describe_stack_events)
+ cfn.create_stack = backoff_wrapper(cfn.create_stack)
+ cfn.list_change_sets = backoff_wrapper(cfn.list_change_sets)
+ cfn.create_change_set = backoff_wrapper(cfn.create_change_set)
+ cfn.update_stack = backoff_wrapper(cfn.update_stack)
+ cfn.describe_stacks = backoff_wrapper(cfn.describe_stacks)
+ cfn.list_stack_resources = backoff_wrapper(cfn.list_stack_resources)
+ cfn.delete_stack = backoff_wrapper(cfn.delete_stack)
+ if boto_supports_termination_protection(cfn):
+ cfn.update_termination_protection = backoff_wrapper(cfn.update_termination_protection)
+
+ stack_info = get_stack_facts(cfn, stack_params['StackName'])
+
+ if module.check_mode:
+ if state == 'absent' and stack_info:
+ module.exit_json(changed=True, msg='Stack would be deleted', meta=[])
+ elif state == 'absent' and not stack_info:
+ module.exit_json(changed=False, msg='Stack doesn\'t exist', meta=[])
+ elif state == 'present' and not stack_info:
+ module.exit_json(changed=True, msg='New stack would be created', meta=[])
+ else:
+ module.exit_json(**check_mode_changeset(module, stack_params, cfn))
+
+ if state == 'present':
+ if not stack_info:
+ result = create_stack(module, stack_params, cfn, module.params.get('events_limit'))
+ elif module.params.get('create_changeset'):
+ result = create_changeset(module, stack_params, cfn, module.params.get('events_limit'))
+ else:
+ if module.params.get('termination_protection') is not None:
+ update_termination_protection(module, cfn, stack_params['StackName'],
+ bool(module.params.get('termination_protection')))
+ result = update_stack(module, stack_params, cfn, module.params.get('events_limit'))
+
+ # format the stack output
+
+ stack = get_stack_facts(cfn, stack_params['StackName'])
+ if stack is not None:
+ if result.get('stack_outputs') is None:
+ # always define stack_outputs, but it may be empty
+ result['stack_outputs'] = {}
+ for output in stack.get('Outputs', []):
+ result['stack_outputs'][output['OutputKey']] = output['OutputValue']
+ stack_resources = []
+ reslist = cfn.list_stack_resources(StackName=stack_params['StackName'])
+ for res in reslist.get('StackResourceSummaries', []):
+ stack_resources.append({
+ "logical_resource_id": res['LogicalResourceId'],
+ "physical_resource_id": res.get('PhysicalResourceId', ''),
+ "resource_type": res['ResourceType'],
+ "last_updated_time": res['LastUpdatedTimestamp'],
+ "status": res['ResourceStatus'],
+ "status_reason": res.get('ResourceStatusReason') # can be blank, apparently
+ })
+ result['stack_resources'] = stack_resources
+
+ elif state == 'absent':
+ # absent state is different because of the way delete_stack works.
+    # the problem is that it doesn't raise an error if the stack isn't found,
+    # so we must describe the stack first
+
+ try:
+ stack = get_stack_facts(cfn, stack_params['StackName'])
+ if not stack:
+ result = {'changed': False, 'output': 'Stack not found.'}
+ else:
+ if stack_params.get('RoleARN') is None:
+ cfn.delete_stack(StackName=stack_params['StackName'])
+ else:
+ cfn.delete_stack(StackName=stack_params['StackName'], RoleARN=stack_params['RoleARN'])
+ result = stack_operation(cfn, stack_params['StackName'], 'DELETE', module.params.get('events_limit'),
+ stack_params.get('ClientRequestToken', None))
+ except Exception as err:
+ module.fail_json(msg=boto_exception(err), exception=traceback.format_exc())
+
+ module.exit_json(**result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/cloudformation_info.py b/test/support/integration/plugins/modules/cloudformation_info.py
new file mode 100644
index 00000000..ee2e5c17
--- /dev/null
+++ b/test/support/integration/plugins/modules/cloudformation_info.py
@@ -0,0 +1,355 @@
+#!/usr/bin/python
+# Copyright: Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: cloudformation_info
+short_description: Obtain information about an AWS CloudFormation stack
+description:
+ - Gets information about an AWS CloudFormation stack.
+ - This module was called C(cloudformation_facts) before Ansible 2.9, returning C(ansible_facts).
+ Note that the M(cloudformation_info) module no longer returns C(ansible_facts)!
+requirements:
+ - boto3 >= 1.0.0
+ - python >= 2.6
+version_added: "2.2"
+author:
+ - Justin Menga (@jmenga)
+ - Kevin Coming (@waffie1)
+options:
+ stack_name:
+ description:
+ - The name or id of the CloudFormation stack. Gathers information on all stacks by default.
+ type: str
+ all_facts:
+ description:
+ - Get all stack information for the stack.
+ type: bool
+ default: false
+ stack_events:
+ description:
+ - Get stack events for the stack.
+ type: bool
+ default: false
+ stack_template:
+ description:
+ - Get stack template body for the stack.
+ type: bool
+ default: false
+ stack_resources:
+ description:
+ - Get stack resources for the stack.
+ type: bool
+ default: false
+ stack_policy:
+ description:
+ - Get stack policy for the stack.
+ type: bool
+ default: false
+ stack_change_sets:
+ description:
+      - Get stack change sets for the stack.
+ type: bool
+ default: false
+ version_added: '2.10'
+extends_documentation_fragment:
+ - aws
+ - ec2
+'''
+
+EXAMPLES = '''
+# Note: These examples do not set authentication details, see the AWS Guide for details.
+
+# Get summary information about a stack
+- cloudformation_info:
+ stack_name: my-cloudformation-stack
+ register: output
+
+- debug:
+ msg: "{{ output['cloudformation']['my-cloudformation-stack'] }}"
+
+# When the module is called as cloudformation_facts, return values are published
+# in ansible_facts['cloudformation'][<stack_name>] and can be used as follows.
+# Note that this is deprecated and will stop working in Ansible 2.13.
+
+- cloudformation_facts:
+ stack_name: my-cloudformation-stack
+
+- debug:
+ msg: "{{ ansible_facts['cloudformation']['my-cloudformation-stack'] }}"
+
+# Get stack outputs, when you have the stack name available as a fact
+- set_fact:
+ stack_name: my-awesome-stack
+
+- cloudformation_info:
+ stack_name: "{{ stack_name }}"
+ register: my_stack
+
+- debug:
+ msg: "{{ my_stack.cloudformation[stack_name].stack_outputs }}"
+
+# Get all stack information about a stack
+- cloudformation_info:
+ stack_name: my-cloudformation-stack
+ all_facts: true
+
+# Get stack resource and stack policy information about a stack
+- cloudformation_info:
+ stack_name: my-cloudformation-stack
+ stack_resources: true
+ stack_policy: true
+
+# Fail if the stack doesn't exist
+- name: try to get facts about a stack but fail if it doesn't exist
+ cloudformation_info:
+ stack_name: nonexistent-stack
+ all_facts: yes
+ failed_when: cloudformation['nonexistent-stack'] is undefined
+'''
+
+RETURN = '''
+stack_description:
+ description: Summary facts about the stack
+ returned: if the stack exists
+ type: dict
+stack_outputs:
+  description: Dictionary of stack outputs, mapping each output's 'OutputKey' to its
+      corresponding 'OutputValue'.
+ returned: if the stack exists
+ type: dict
+ sample:
+ ApplicationDatabaseName: dazvlpr01xj55a.ap-southeast-2.rds.amazonaws.com
+stack_parameters:
+  description: Dictionary of stack parameters, mapping each parameter's 'ParameterKey' to its
+      corresponding 'ParameterValue'.
+ returned: if the stack exists
+ type: dict
+ sample:
+ DatabaseEngine: mysql
+ DatabasePassword: "***"
+stack_events:
+ description: All stack events for the stack
+ returned: only if all_facts or stack_events is true and the stack exists
+ type: list
+stack_policy:
+ description: Describes the stack policy for the stack
+ returned: only if all_facts or stack_policy is true and the stack exists
+ type: dict
+stack_template:
+ description: Describes the stack template for the stack
+ returned: only if all_facts or stack_template is true and the stack exists
+ type: dict
+stack_resource_list:
+ description: Describes stack resources for the stack
+  returned: only if all_facts or stack_resources is true and the stack exists
+ type: list
+stack_resources:
+  description: Dictionary of stack resources, mapping each resource's 'LogicalResourceId' to its
+      corresponding 'PhysicalResourceId'.
+  returned: only if all_facts or stack_resources is true and the stack exists
+ type: dict
+ sample:
+ AutoScalingGroup: "dev-someapp-AutoscalingGroup-1SKEXXBCAN0S7"
+ AutoScalingSecurityGroup: "sg-abcd1234"
+ ApplicationDatabase: "dazvlpr01xj55a"
+stack_change_sets:
+  description: A list of stack change sets. Each item in the list represents the details of a specific changeset.
+  returned: only if all_facts or stack_change_sets is true and the stack exists
+ type: list
+'''
+
+import json
+import traceback
+
+from functools import partial
+from ansible.module_utils._text import to_native
+from ansible.module_utils.aws.core import AnsibleAWSModule
+from ansible.module_utils.ec2 import (camel_dict_to_snake_dict, AWSRetry, boto3_tag_list_to_ansible_dict)
+
+try:
+ import botocore
+except ImportError:
+ pass # handled by AnsibleAWSModule
+
+
+class CloudFormationServiceManager:
+ """Handles CloudFormation Services"""
+
+ def __init__(self, module):
+ self.module = module
+ self.client = module.client('cloudformation')
+
+ @AWSRetry.exponential_backoff(retries=5, delay=5)
+ def describe_stacks_with_backoff(self, **kwargs):
+ paginator = self.client.get_paginator('describe_stacks')
+ return paginator.paginate(**kwargs).build_full_result()['Stacks']
+
+ def describe_stacks(self, stack_name=None):
+ try:
+ kwargs = {'StackName': stack_name} if stack_name else {}
+ response = self.describe_stacks_with_backoff(**kwargs)
+ if response is not None:
+ return response
+ self.module.fail_json(msg="Error describing stack(s) - an empty response was returned")
+ except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ if 'does not exist' in e.response['Error']['Message']:
+ # missing stack, don't bail.
+ return {}
+ self.module.fail_json_aws(e, msg="Error describing stack " + stack_name)
+
+ @AWSRetry.exponential_backoff(retries=5, delay=5)
+ def list_stack_resources_with_backoff(self, stack_name):
+ paginator = self.client.get_paginator('list_stack_resources')
+ return paginator.paginate(StackName=stack_name).build_full_result()['StackResourceSummaries']
+
+ def list_stack_resources(self, stack_name):
+ try:
+ return self.list_stack_resources_with_backoff(stack_name)
+ except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ self.module.fail_json_aws(e, msg="Error listing stack resources for stack " + stack_name)
+
+ @AWSRetry.exponential_backoff(retries=5, delay=5)
+ def describe_stack_events_with_backoff(self, stack_name):
+ paginator = self.client.get_paginator('describe_stack_events')
+ return paginator.paginate(StackName=stack_name).build_full_result()['StackEvents']
+
+ def describe_stack_events(self, stack_name):
+ try:
+ return self.describe_stack_events_with_backoff(stack_name)
+ except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ self.module.fail_json_aws(e, msg="Error listing stack events for stack " + stack_name)
+
+ @AWSRetry.exponential_backoff(retries=5, delay=5)
+ def list_stack_change_sets_with_backoff(self, stack_name):
+ paginator = self.client.get_paginator('list_change_sets')
+ return paginator.paginate(StackName=stack_name).build_full_result()['Summaries']
+
+ @AWSRetry.exponential_backoff(retries=5, delay=5)
+ def describe_stack_change_set_with_backoff(self, **kwargs):
+ paginator = self.client.get_paginator('describe_change_set')
+ return paginator.paginate(**kwargs).build_full_result()
+
+ def describe_stack_change_sets(self, stack_name):
+ changes = []
+ try:
+ change_sets = self.list_stack_change_sets_with_backoff(stack_name)
+ for item in change_sets:
+ changes.append(self.describe_stack_change_set_with_backoff(
+ StackName=stack_name,
+ ChangeSetName=item['ChangeSetName']))
+ return changes
+ except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ self.module.fail_json_aws(e, msg="Error describing stack change sets for stack " + stack_name)
+
+ @AWSRetry.exponential_backoff(retries=5, delay=5)
+ def get_stack_policy_with_backoff(self, stack_name):
+ return self.client.get_stack_policy(StackName=stack_name)
+
+ def get_stack_policy(self, stack_name):
+ try:
+ response = self.get_stack_policy_with_backoff(stack_name)
+ stack_policy = response.get('StackPolicyBody')
+ if stack_policy:
+ return json.loads(stack_policy)
+ return dict()
+ except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ self.module.fail_json_aws(e, msg="Error getting stack policy for stack " + stack_name)
+
+ @AWSRetry.exponential_backoff(retries=5, delay=5)
+ def get_template_with_backoff(self, stack_name):
+ return self.client.get_template(StackName=stack_name)
+
+ def get_template(self, stack_name):
+ try:
+ response = self.get_template_with_backoff(stack_name)
+ return response.get('TemplateBody')
+ except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ self.module.fail_json_aws(e, msg="Error getting stack template for stack " + stack_name)
+
+
+def to_dict(items, key, value):
+ ''' Transforms a list of items to a Key/Value dictionary '''
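+    # Illustrative example:
+    #   to_dict([{'OutputKey': 'Port', 'OutputValue': '80'}], 'OutputKey', 'OutputValue')
+    #   returns {'Port': '80'}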
+ if items:
+ return dict(zip([i.get(key) for i in items], [i.get(value) for i in items]))
+ else:
+ return dict()
+
+
+def main():
+ argument_spec = dict(
+ stack_name=dict(),
+ all_facts=dict(required=False, default=False, type='bool'),
+ stack_policy=dict(required=False, default=False, type='bool'),
+ stack_events=dict(required=False, default=False, type='bool'),
+ stack_resources=dict(required=False, default=False, type='bool'),
+ stack_template=dict(required=False, default=False, type='bool'),
+ stack_change_sets=dict(required=False, default=False, type='bool'),
+ )
+ module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True)
+
+ is_old_facts = module._name == 'cloudformation_facts'
+ if is_old_facts:
+ module.deprecate("The 'cloudformation_facts' module has been renamed to 'cloudformation_info', "
+ "and the renamed one no longer returns ansible_facts",
+ version='2.13', collection_name='ansible.builtin')
+
+ service_mgr = CloudFormationServiceManager(module)
+
+ if is_old_facts:
+ result = {'ansible_facts': {'cloudformation': {}}}
+ else:
+ result = {'cloudformation': {}}
+
+ for stack_description in service_mgr.describe_stacks(module.params.get('stack_name')):
+ facts = {'stack_description': stack_description}
+ stack_name = stack_description.get('StackName')
+
+ # Create stack output and stack parameter dictionaries
+ if facts['stack_description']:
+ facts['stack_outputs'] = to_dict(facts['stack_description'].get('Outputs'), 'OutputKey', 'OutputValue')
+ facts['stack_parameters'] = to_dict(facts['stack_description'].get('Parameters'),
+ 'ParameterKey', 'ParameterValue')
+ facts['stack_tags'] = boto3_tag_list_to_ansible_dict(facts['stack_description'].get('Tags'))
+
+ # Create optional stack outputs
+ all_facts = module.params.get('all_facts')
+ if all_facts or module.params.get('stack_resources'):
+ facts['stack_resource_list'] = service_mgr.list_stack_resources(stack_name)
+ facts['stack_resources'] = to_dict(facts.get('stack_resource_list'),
+ 'LogicalResourceId', 'PhysicalResourceId')
+ if all_facts or module.params.get('stack_template'):
+ facts['stack_template'] = service_mgr.get_template(stack_name)
+ if all_facts or module.params.get('stack_policy'):
+ facts['stack_policy'] = service_mgr.get_stack_policy(stack_name)
+ if all_facts or module.params.get('stack_events'):
+ facts['stack_events'] = service_mgr.describe_stack_events(stack_name)
+ if all_facts or module.params.get('stack_change_sets'):
+ facts['stack_change_sets'] = service_mgr.describe_stack_change_sets(stack_name)
+
+ if is_old_facts:
+ result['ansible_facts']['cloudformation'][stack_name] = facts
+ else:
+ result['cloudformation'][stack_name] = camel_dict_to_snake_dict(facts, ignore_list=('stack_outputs',
+ 'stack_parameters',
+ 'stack_policy',
+ 'stack_resources',
+ 'stack_tags',
+ 'stack_template'))
+
+ module.exit_json(changed=False, **result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/deploy_helper.py b/test/support/integration/plugins/modules/deploy_helper.py
new file mode 100644
index 00000000..38594dde
--- /dev/null
+++ b/test/support/integration/plugins/modules/deploy_helper.py
@@ -0,0 +1,521 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2014, Jasper N. Brouwer <jasper@nerdsweide.nl>
+# (c) 2014, Ramon de la Fuente <ramon@delafuente.nl>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: deploy_helper
+version_added: "2.0"
+author: "Ramon de la Fuente (@ramondelafuente)"
+short_description: Manages some of the steps common in deploying projects.
+description:
+ - The Deploy Helper manages some of the steps common in deploying software.
+ It creates a folder structure, manages a symlink for the current release
+ and cleans up old releases.
+ - "Running it with the C(state=query) or C(state=present) will return the C(deploy_helper) fact.
+ C(project_path), whatever you set in the path parameter,
+ C(current_path), the path to the symlink that points to the active release,
+ C(releases_path), the path to the folder to keep releases in,
+ C(shared_path), the path to the folder to keep shared resources in,
+ C(unfinished_filename), the file to check for to recognize unfinished builds,
+ C(previous_release), the release the 'current' symlink is pointing to,
+ C(previous_release_path), the full path to the 'current' symlink target,
+ C(new_release), either the 'release' parameter or a generated timestamp,
+ C(new_release_path), the path to the new release folder (not created by the module)."
+
+options:
+ path:
+ required: True
+ aliases: ['dest']
+ description:
+ - the root path of the project. Alias I(dest).
+ Returned in the C(deploy_helper.project_path) fact.
+
+ state:
+ description:
+ - the state of the project.
+ C(query) will only gather facts,
+ C(present) will create the project I(root) folder, and in it the I(releases) and I(shared) folders,
+ C(finalize) will remove the unfinished_filename file, create a symlink to the newly
+ deployed release and optionally clean old releases,
+ C(clean) will remove failed & old releases,
+      C(absent) will remove the project folder (equivalent to the M(file) module with C(state=absent))
+ choices: [ present, finalize, absent, clean, query ]
+ default: present
+
+ release:
+ description:
+      - the release version that is being deployed. Defaults to a timestamp format %Y%m%d%H%M%S (e.g. '20141119223359').
+ This parameter is optional during C(state=present), but needs to be set explicitly for C(state=finalize).
+ You can use the generated fact C(release={{ deploy_helper.new_release }}).
+
+ releases_path:
+ description:
+ - the name of the folder that will hold the releases. This can be relative to C(path) or absolute.
+ Returned in the C(deploy_helper.releases_path) fact.
+ default: releases
+
+ shared_path:
+ description:
+ - the name of the folder that will hold the shared resources. This can be relative to C(path) or absolute.
+ If this is set to an empty string, no shared folder will be created.
+ Returned in the C(deploy_helper.shared_path) fact.
+ default: shared
+
+ current_path:
+ description:
+ - the name of the symlink that is created when the deploy is finalized. Used in C(finalize) and C(clean).
+ Returned in the C(deploy_helper.current_path) fact.
+ default: current
+
+ unfinished_filename:
+ description:
+ - the name of the file that indicates a deploy has not finished. All folders in the releases_path that
+ contain this file will be deleted on C(state=finalize) with clean=True, or C(state=clean). This file is
+ automatically deleted from the I(new_release_path) during C(state=finalize).
+ default: DEPLOY_UNFINISHED
+
+ clean:
+ description:
+ - Whether to run the clean procedure in case of C(state=finalize).
+ type: bool
+ default: 'yes'
+
+ keep_releases:
+ description:
+ - the number of old releases to keep when cleaning. Used in C(finalize) and C(clean). Any unfinished builds
+ will be deleted first, so only correct releases will count. The current version will not count.
+ default: 5
+
+notes:
+ - Facts are only returned for C(state=query) and C(state=present). If you use both, you should pass any overridden
+ parameters to both calls, otherwise the second call will overwrite the facts of the first one.
+ - When using C(state=clean), the releases are ordered by I(creation date). You should be able to switch to a
+ new naming strategy without problems.
+ - Because of the default behaviour of generating the I(new_release) fact, this module will not be idempotent
+ unless you pass your own release name with C(release). Due to the nature of deploying software, this should not
+ be much of a problem.
+'''
+
+EXAMPLES = '''
+
+# General explanation, starting with an example folder structure for a project:
+
+# root:
+# releases:
+# - 20140415234508
+# - 20140415235146
+# - 20140416082818
+#
+# shared:
+# - sessions
+# - uploads
+#
+# current: releases/20140416082818
+
+
+# The 'releases' folder holds all the available releases. A release is a complete build of the application being
+# deployed. This can be a clone of a repository for example, or a sync of a local folder on your filesystem.
+# Having timestamped folders is one way of having distinct releases, but you could choose your own strategy like
+# git tags or commit hashes.
+#
+# During a deploy, a new folder should be created in the releases folder and any build steps required should be
+# performed. Once the new build is ready, the deploy procedure is 'finalized' by replacing the 'current' symlink
+# with a link to this build.
+#
+# The 'shared' folder holds any resource that is shared between releases. Examples of this are web-server
+# session files, or files uploaded by users of your application. It's quite common to have symlinks from a release
+# folder pointing to a shared/subfolder, and creating these links would be automated as part of the build steps.
+#
+# The 'current' symlink points to one of the releases. Probably the latest one, unless a deploy is in progress.
+# The web-server's root for the project will go through this symlink, so the 'downtime' when switching to a new
+# release is reduced to the time it takes to switch the link.
+#
+# To distinguish between successful builds and unfinished ones, a file can be placed in the folder of the release
+# that is currently in progress. The existence of this file will mark it as unfinished, and allow an automated
+# procedure to remove it during cleanup.
+
+
+# Typical usage
+- name: Initialize the deploy root and gather facts
+ deploy_helper:
+ path: /path/to/root
+- name: Clone the project to the new release folder
+ git:
+ repo: git://foosball.example.org/path/to/repo.git
+ dest: '{{ deploy_helper.new_release_path }}'
+ version: v1.1.1
+- name: Add an unfinished file, to allow cleanup on successful finalize
+ file:
+ path: '{{ deploy_helper.new_release_path }}/{{ deploy_helper.unfinished_filename }}'
+ state: touch
+- name: Perform some build steps, like running your dependency manager for example
+ composer:
+ command: install
+ working_dir: '{{ deploy_helper.new_release_path }}'
+- name: Create some folders in the shared folder
+ file:
+ path: '{{ deploy_helper.shared_path }}/{{ item }}'
+ state: directory
+ with_items:
+ - sessions
+ - uploads
+- name: Add symlinks from the new release to the shared folder
+ file:
+ path: '{{ deploy_helper.new_release_path }}/{{ item.path }}'
+ src: '{{ deploy_helper.shared_path }}/{{ item.src }}'
+ state: link
+ with_items:
+ - path: app/sessions
+ src: sessions
+ - path: web/uploads
+ src: uploads
+- name: Finalize the deploy, removing the unfinished file and switching the symlink
+ deploy_helper:
+ path: /path/to/root
+ release: '{{ deploy_helper.new_release }}'
+ state: finalize
+
+# Retrieving facts before running a deploy
+- name: Run 'state=query' to gather facts without changing anything
+ deploy_helper:
+ path: /path/to/root
+ state: query
+# Remember to set the 'release' parameter when you actually call 'state=present' later
+- name: Initialize the deploy root
+ deploy_helper:
+ path: /path/to/root
+ release: '{{ deploy_helper.new_release }}'
+ state: present
+
+# all paths can be absolute or relative (to the 'path' parameter)
+- deploy_helper:
+ path: /path/to/root
+ releases_path: /var/www/project/releases
+ shared_path: /var/www/shared
+ current_path: /var/www/active
+
+# Using your own naming strategy for releases (a version tag in this case):
+- deploy_helper:
+ path: /path/to/root
+ release: v1.1.1
+ state: present
+- deploy_helper:
+ path: /path/to/root
+ release: '{{ deploy_helper.new_release }}'
+ state: finalize
+
+# Using a different unfinished_filename:
+- deploy_helper:
+ path: /path/to/root
+ unfinished_filename: README.md
+ release: '{{ deploy_helper.new_release }}'
+ state: finalize
+
+# Postponing the cleanup of older builds:
+- deploy_helper:
+ path: /path/to/root
+ release: '{{ deploy_helper.new_release }}'
+ state: finalize
+ clean: False
+- deploy_helper:
+ path: /path/to/root
+ state: clean
+# Or running the cleanup ahead of the new deploy
+- deploy_helper:
+ path: /path/to/root
+ state: clean
+- deploy_helper:
+ path: /path/to/root
+ state: present
+
+# Keeping more old releases:
+- deploy_helper:
+ path: /path/to/root
+ release: '{{ deploy_helper.new_release }}'
+ state: finalize
+ keep_releases: 10
+# Or, if you use 'clean=false' on finalize:
+- deploy_helper:
+ path: /path/to/root
+ state: clean
+ keep_releases: 10
+
+# Removing the entire project root folder
+- deploy_helper:
+ path: /path/to/root
+ state: absent
+
+# Debugging the facts returned by the module
+- deploy_helper:
+ path: /path/to/root
+- debug:
+ var: deploy_helper
+'''
+
+import os
+import shutil
+import time
+import traceback
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils._text import to_native
+
+
+class DeployHelper(object):
+
+ def __init__(self, module):
+ self.module = module
+ self.file_args = module.load_file_common_arguments(module.params)
+
+ self.clean = module.params['clean']
+ self.current_path = module.params['current_path']
+ self.keep_releases = module.params['keep_releases']
+ self.path = module.params['path']
+ self.release = module.params['release']
+ self.releases_path = module.params['releases_path']
+ self.shared_path = module.params['shared_path']
+ self.state = module.params['state']
+ self.unfinished_filename = module.params['unfinished_filename']
+
+ def gather_facts(self):
+ current_path = os.path.join(self.path, self.current_path)
+ releases_path = os.path.join(self.path, self.releases_path)
+ if self.shared_path:
+ shared_path = os.path.join(self.path, self.shared_path)
+ else:
+ shared_path = None
+
+ previous_release, previous_release_path = self._get_last_release(current_path)
+
+ if not self.release and (self.state == 'query' or self.state == 'present'):
+ self.release = time.strftime("%Y%m%d%H%M%S")
+
+ if self.release:
+ new_release_path = os.path.join(releases_path, self.release)
+ else:
+ new_release_path = None
+
+ return {
+ 'project_path': self.path,
+ 'current_path': current_path,
+ 'releases_path': releases_path,
+ 'shared_path': shared_path,
+ 'previous_release': previous_release,
+ 'previous_release_path': previous_release_path,
+ 'new_release': self.release,
+ 'new_release_path': new_release_path,
+ 'unfinished_filename': self.unfinished_filename
+ }
+
+ def delete_path(self, path):
+ if not os.path.lexists(path):
+ return False
+
+ if not os.path.isdir(path):
+ self.module.fail_json(msg="%s exists but is not a directory" % path)
+
+ if not self.module.check_mode:
+ try:
+ shutil.rmtree(path, ignore_errors=False)
+ except Exception as e:
+ self.module.fail_json(msg="rmtree failed: %s" % to_native(e), exception=traceback.format_exc())
+
+ return True
+
+ def create_path(self, path):
+ changed = False
+
+ if not os.path.lexists(path):
+ changed = True
+ if not self.module.check_mode:
+ os.makedirs(path)
+
+ elif not os.path.isdir(path):
+ self.module.fail_json(msg="%s exists but is not a directory" % path)
+
+ changed += self.module.set_directory_attributes_if_different(self._get_file_args(path), changed)
+
+ return changed
+
+ def check_link(self, path):
+ if os.path.lexists(path):
+ if not os.path.islink(path):
+ self.module.fail_json(msg="%s exists but is not a symbolic link" % path)
+
+ def create_link(self, source, link_name):
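+        # When the link already exists it is swapped atomically: a temporary link is
+        # created first, then os.rename()d over the old one so it never dangles.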
+ changed = False
+
+ if os.path.islink(link_name):
+ norm_link = os.path.normpath(os.path.realpath(link_name))
+ norm_source = os.path.normpath(os.path.realpath(source))
+ if norm_link == norm_source:
+ changed = False
+ else:
+ changed = True
+ if not self.module.check_mode:
+ if not os.path.lexists(source):
+ self.module.fail_json(msg="the symlink target %s doesn't exists" % source)
+ tmp_link_name = link_name + '.' + self.unfinished_filename
+ if os.path.islink(tmp_link_name):
+ os.unlink(tmp_link_name)
+ os.symlink(source, tmp_link_name)
+ os.rename(tmp_link_name, link_name)
+ else:
+ changed = True
+ if not self.module.check_mode:
+ os.symlink(source, link_name)
+
+ return changed
+
+ def remove_unfinished_file(self, new_release_path):
+ changed = False
+ unfinished_file_path = os.path.join(new_release_path, self.unfinished_filename)
+ if os.path.lexists(unfinished_file_path):
+ changed = True
+ if not self.module.check_mode:
+ os.remove(unfinished_file_path)
+
+ return changed
+
+ def remove_unfinished_builds(self, releases_path):
+ changes = 0
+
+ for release in os.listdir(releases_path):
+ if os.path.isfile(os.path.join(releases_path, release, self.unfinished_filename)):
+ if self.module.check_mode:
+ changes += 1
+ else:
+ changes += self.delete_path(os.path.join(releases_path, release))
+
+ return changes
+
+ def remove_unfinished_link(self, path):
+ changed = False
+
+ tmp_link_name = os.path.join(path, self.release + '.' + self.unfinished_filename)
+ if not self.module.check_mode and os.path.exists(tmp_link_name):
+ changed = True
+ os.remove(tmp_link_name)
+
+ return changed
+
+ def cleanup(self, releases_path, reserve_version):
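+        # Keeps the newest self.keep_releases releases (ordered by ctime) plus the
+        # reserved version; older ones are deleted, or merely counted in check mode.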
+ changes = 0
+
+ if os.path.lexists(releases_path):
+ releases = [f for f in os.listdir(releases_path) if os.path.isdir(os.path.join(releases_path, f))]
+ try:
+ releases.remove(reserve_version)
+ except ValueError:
+ pass
+
+ if not self.module.check_mode:
+ releases.sort(key=lambda x: os.path.getctime(os.path.join(releases_path, x)), reverse=True)
+ for release in releases[self.keep_releases:]:
+ changes += self.delete_path(os.path.join(releases_path, release))
+ elif len(releases) > self.keep_releases:
+ changes += (len(releases) - self.keep_releases)
+
+ return changes
+
+ def _get_file_args(self, path):
+ file_args = self.file_args.copy()
+ file_args['path'] = path
+ return file_args
+
+ def _get_last_release(self, current_path):
+ previous_release = None
+ previous_release_path = None
+
+ if os.path.lexists(current_path):
+ previous_release_path = os.path.realpath(current_path)
+ previous_release = os.path.basename(previous_release_path)
+
+ return previous_release, previous_release_path
+
+
+def main():
+
+ module = AnsibleModule(
+ argument_spec=dict(
+ path=dict(aliases=['dest'], required=True, type='path'),
+ release=dict(required=False, type='str', default=None),
+ releases_path=dict(required=False, type='str', default='releases'),
+ shared_path=dict(required=False, type='path', default='shared'),
+ current_path=dict(required=False, type='path', default='current'),
+ keep_releases=dict(required=False, type='int', default=5),
+ clean=dict(required=False, type='bool', default=True),
+ unfinished_filename=dict(required=False, type='str', default='DEPLOY_UNFINISHED'),
+ state=dict(required=False, choices=['present', 'absent', 'clean', 'finalize', 'query'], default='present')
+ ),
+ add_file_common_args=True,
+ supports_check_mode=True
+ )
+
+ deploy_helper = DeployHelper(module)
+ facts = deploy_helper.gather_facts()
+
+ result = {
+ 'state': deploy_helper.state
+ }
+
+ changes = 0
+
+ if deploy_helper.state == 'query':
+ result['ansible_facts'] = {'deploy_helper': facts}
+
+ elif deploy_helper.state == 'present':
+ deploy_helper.check_link(facts['current_path'])
+ changes += deploy_helper.create_path(facts['project_path'])
+ changes += deploy_helper.create_path(facts['releases_path'])
+ if deploy_helper.shared_path:
+ changes += deploy_helper.create_path(facts['shared_path'])
+
+ result['ansible_facts'] = {'deploy_helper': facts}
+
+ elif deploy_helper.state == 'finalize':
+ if not deploy_helper.release:
+ module.fail_json(msg="'release' is a required parameter for state=finalize (try the 'deploy_helper.new_release' fact)")
+ if deploy_helper.keep_releases <= 0:
+ module.fail_json(msg="'keep_releases' should be at least 1")
+
+ changes += deploy_helper.remove_unfinished_file(facts['new_release_path'])
+ changes += deploy_helper.create_link(facts['new_release_path'], facts['current_path'])
+ if deploy_helper.clean:
+ changes += deploy_helper.remove_unfinished_link(facts['project_path'])
+ changes += deploy_helper.remove_unfinished_builds(facts['releases_path'])
+ changes += deploy_helper.cleanup(facts['releases_path'], facts['new_release'])
+
+ elif deploy_helper.state == 'clean':
+ changes += deploy_helper.remove_unfinished_link(facts['project_path'])
+ changes += deploy_helper.remove_unfinished_builds(facts['releases_path'])
+ changes += deploy_helper.cleanup(facts['releases_path'], facts['new_release'])
+
+ elif deploy_helper.state == 'absent':
+ # destroy the facts
+ result['ansible_facts'] = {'deploy_helper': []}
+ changes += deploy_helper.delete_path(facts['project_path'])
+
+ if changes > 0:
+ result['changed'] = True
+ else:
+ result['changed'] = False
+
+ module.exit_json(**result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/docker_swarm.py b/test/support/integration/plugins/modules/docker_swarm.py
new file mode 100644
index 00000000..a2c076c5
--- /dev/null
+++ b/test/support/integration/plugins/modules/docker_swarm.py
@@ -0,0 +1,681 @@
+#!/usr/bin/python
+
+# Copyright 2016 Red Hat | Ansible
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = '''
+---
+module: docker_swarm
+short_description: Manage Swarm cluster
+version_added: "2.7"
+description:
+ - Create a new Swarm cluster.
+ - Add/Remove nodes or managers to an existing cluster.
+options:
+ advertise_addr:
+ description:
+ - Externally reachable address advertised to other nodes.
+ - This can either be an address/port combination
+ in the form C(192.168.1.1:4567), or an interface followed by a
+ port number, like C(eth0:4567).
+ - If the port number is omitted,
+ the port number from the listen address is used.
+ - If I(advertise_addr) is not specified, it will be automatically
+ detected when possible.
+ - Only used when swarm is initialised or joined. Because of this it's not
+ considered for idempotency checking.
+ type: str
+ default_addr_pool:
+ description:
+ - Default address pool in CIDR format.
+ - Only used when swarm is initialised. Because of this it's not considered
+ for idempotency checking.
+ - Requires API version >= 1.39.
+ type: list
+ elements: str
+ version_added: "2.8"
+ subnet_size:
+ description:
+ - Default address pool subnet mask length.
+ - Only used when swarm is initialised. Because of this it's not considered
+ for idempotency checking.
+ - Requires API version >= 1.39.
+ type: int
+ version_added: "2.8"
+ listen_addr:
+ description:
+ - Listen address used for inter-manager communication.
+ - This can either be an address/port combination in the form
+ C(192.168.1.1:4567), or an interface followed by a port number,
+ like C(eth0:4567).
+ - If the port number is omitted, the default swarm listening port
+ is used.
+ - Only used when swarm is initialised or joined. Because of this it's not
+ considered for idempotency checking.
+ type: str
+ default: 0.0.0.0:2377
+ force:
+ description:
+ - Use with state C(present) to force creating a new Swarm, even if already part of one.
+ - Use with state C(absent) to Leave the swarm even if this node is a manager.
+ type: bool
+ default: no
+ state:
+ description:
+ - Set to C(present), to create/update a new cluster.
+ - Set to C(join), to join an existing cluster.
+ - Set to C(absent), to leave an existing cluster.
+ - Set to C(remove), to remove an absent node from the cluster.
+ Note that removing requires Docker SDK for Python >= 2.4.0.
+    - Set to C(inspect) to display swarm information.
+ type: str
+ default: present
+ choices:
+ - present
+ - join
+ - absent
+ - remove
+ - inspect
+ node_id:
+ description:
+ - Swarm id of the node to remove.
+ - Used with I(state=remove).
+ type: str
+ join_token:
+ description:
+ - Swarm token used to join a swarm cluster.
+ - Used with I(state=join).
+ type: str
+ remote_addrs:
+ description:
+ - Remote address of one or more manager nodes of an existing Swarm to connect to.
+ - Used with I(state=join).
+ type: list
+ elements: str
+ task_history_retention_limit:
+ description:
+ - Maximum number of tasks history stored.
+ - Docker default value is C(5).
+ type: int
+ snapshot_interval:
+ description:
+ - Number of logs entries between snapshot.
+ - Docker default value is C(10000).
+ type: int
+ keep_old_snapshots:
+ description:
+ - Number of snapshots to keep beyond the current snapshot.
+ - Docker default value is C(0).
+ type: int
+ log_entries_for_slow_followers:
+ description:
+ - Number of log entries to keep around to sync up slow followers after a snapshot is created.
+ type: int
+ heartbeat_tick:
+ description:
+ - Amount of ticks (in seconds) between each heartbeat.
+ - Docker default value is C(1s).
+ type: int
+ election_tick:
+ description:
+ - Amount of ticks (in seconds) needed without a leader to trigger a new election.
+ - Docker default value is C(10s).
+ type: int
+ dispatcher_heartbeat_period:
+ description:
+ - The delay for an agent to send a heartbeat to the dispatcher.
+ - Docker default value is C(5s).
+ type: int
+ node_cert_expiry:
+ description:
+ - Automatic expiry for nodes certificates.
+ - Docker default value is C(3months).
+ type: int
+ name:
+ description:
+ - The name of the swarm.
+ type: str
+ labels:
+ description:
+ - User-defined key/value metadata.
+ - Label operations in this module apply to the docker swarm cluster.
+ Use M(docker_node) module to add/modify/remove swarm node labels.
+ - Requires API version >= 1.32.
+ type: dict
+ signing_ca_cert:
+ description:
+ - The desired signing CA certificate for all swarm node TLS leaf certificates, in PEM format.
+ - This must not be a path to a certificate, but the contents of the certificate.
+ - Requires API version >= 1.30.
+ type: str
+ signing_ca_key:
+ description:
+ - The desired signing CA key for all swarm node TLS leaf certificates, in PEM format.
+ - This must not be a path to a key, but the contents of the key.
+ - Requires API version >= 1.30.
+ type: str
+ ca_force_rotate:
+ description:
+ - An integer whose purpose is to force swarm to generate a new signing CA certificate and key,
+ if none have been specified.
+ - Docker default value is C(0).
+ - Requires API version >= 1.30.
+ type: int
+ autolock_managers:
+ description:
+ - If set, generate a key and use it to lock data stored on the managers.
+ - Docker default value is C(no).
+ - M(docker_swarm_info) can be used to retrieve the unlock key.
+ type: bool
+ rotate_worker_token:
+ description: Rotate the worker join token.
+ type: bool
+ default: no
+ rotate_manager_token:
+ description: Rotate the manager join token.
+ type: bool
+ default: no
+extends_documentation_fragment:
+ - docker
+ - docker.docker_py_1_documentation
+requirements:
+ - "L(Docker SDK for Python,https://docker-py.readthedocs.io/en/stable/) >= 1.10.0 (use L(docker-py,https://pypi.org/project/docker-py/) for Python 2.6)"
+ - Docker API >= 1.25
+author:
+ - Thierry Bouvet (@tbouvet)
+ - Piotr Wojciechowski (@WojciechowskiPiotr)
+'''
+
+EXAMPLES = '''
+
+- name: Init a new swarm with default parameters
+ docker_swarm:
+ state: present
+
+- name: Update swarm configuration
+ docker_swarm:
+ state: present
+ election_tick: 5
+
+- name: Add nodes
+ docker_swarm:
+ state: join
+ advertise_addr: 192.168.1.2
+ join_token: SWMTKN-1--xxxxx
+ remote_addrs: [ '192.168.1.1:2377' ]
+
+- name: Leave swarm for a node
+ docker_swarm:
+ state: absent
+
+- name: Remove a swarm manager
+ docker_swarm:
+ state: absent
+ force: true
+
+- name: Remove node from swarm
+ docker_swarm:
+ state: remove
+ node_id: mynode
+
+- name: Inspect swarm
+ docker_swarm:
+ state: inspect
+ register: swarm_info
+'''
+
+RETURN = '''
+swarm_facts:
+  description: Information about the swarm.
+ returned: success
+ type: dict
+ contains:
+ JoinTokens:
+ description: Tokens to connect to the Swarm.
+ returned: success
+ type: dict
+ contains:
+ Worker:
+        description: Token to create a new *worker* node.
+ returned: success
+ type: str
+ example: SWMTKN-1--xxxxx
+ Manager:
+        description: Token to create a new *manager* node.
+ returned: success
+ type: str
+ example: SWMTKN-1--xxxxx
+ UnlockKey:
+ description: The swarm unlock-key if I(autolock_managers) is C(true).
+ returned: on success if I(autolock_managers) is C(true)
+ and swarm is initialised, or if I(autolock_managers) has changed.
+ type: str
+ example: SWMKEY-1-xxx
+
+actions:
+ description: Provides the actions done on the swarm.
+  returned: when an action failed.
+ type: list
+ elements: str
+ example: "['This cluster is already a swarm cluster']"
+
+'''
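+
+# A registered result can feed these return values back into this module to
+# join further nodes; a minimal playbook sketch (the address below is
+# illustrative only):
+#
+#   - docker_swarm:
+#       state: present
+#     register: output
+#
+#   - docker_swarm:
+#       state: join
+#       join_token: "{{ output.swarm_facts.JoinTokens.Worker }}"
+#       remote_addrs: [ '192.168.1.1:2377' ]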
+
+import json
+import traceback
+
+try:
+ from docker.errors import DockerException, APIError
+except ImportError:
+ # missing Docker SDK for Python handled in ansible.module_utils.docker.common
+ pass
+
+from ansible.module_utils.docker.common import (
+ DockerBaseClass,
+ DifferenceTracker,
+ RequestException,
+)
+
+from ansible.module_utils.docker.swarm import AnsibleDockerSwarmClient
+
+from ansible.module_utils._text import to_native
+
+
+class TaskParameters(DockerBaseClass):
+ def __init__(self):
+ super(TaskParameters, self).__init__()
+
+ self.advertise_addr = None
+ self.listen_addr = None
+ self.remote_addrs = None
+ self.join_token = None
+
+ # Spec
+ self.snapshot_interval = None
+ self.task_history_retention_limit = None
+ self.keep_old_snapshots = None
+ self.log_entries_for_slow_followers = None
+ self.heartbeat_tick = None
+ self.election_tick = None
+ self.dispatcher_heartbeat_period = None
+ self.node_cert_expiry = None
+ self.name = None
+ self.labels = None
+ self.log_driver = None
+ self.signing_ca_cert = None
+ self.signing_ca_key = None
+ self.ca_force_rotate = None
+ self.autolock_managers = None
+ self.rotate_worker_token = None
+ self.rotate_manager_token = None
+ self.default_addr_pool = None
+ self.subnet_size = None
+
+ @staticmethod
+ def from_ansible_params(client):
+ result = TaskParameters()
+ for key, value in client.module.params.items():
+ if key in result.__dict__:
+ setattr(result, key, value)
+
+ result.update_parameters(client)
+ return result
+
+ def update_from_swarm_info(self, swarm_info):
+ spec = swarm_info['Spec']
+
+ ca_config = spec.get('CAConfig') or dict()
+ if self.node_cert_expiry is None:
+ self.node_cert_expiry = ca_config.get('NodeCertExpiry')
+ if self.ca_force_rotate is None:
+ self.ca_force_rotate = ca_config.get('ForceRotate')
+
+ dispatcher = spec.get('Dispatcher') or dict()
+ if self.dispatcher_heartbeat_period is None:
+ self.dispatcher_heartbeat_period = dispatcher.get('HeartbeatPeriod')
+
+ raft = spec.get('Raft') or dict()
+ if self.snapshot_interval is None:
+ self.snapshot_interval = raft.get('SnapshotInterval')
+ if self.keep_old_snapshots is None:
+ self.keep_old_snapshots = raft.get('KeepOldSnapshots')
+ if self.heartbeat_tick is None:
+ self.heartbeat_tick = raft.get('HeartbeatTick')
+ if self.log_entries_for_slow_followers is None:
+ self.log_entries_for_slow_followers = raft.get('LogEntriesForSlowFollowers')
+ if self.election_tick is None:
+ self.election_tick = raft.get('ElectionTick')
+
+ orchestration = spec.get('Orchestration') or dict()
+ if self.task_history_retention_limit is None:
+ self.task_history_retention_limit = orchestration.get('TaskHistoryRetentionLimit')
+
+ encryption_config = spec.get('EncryptionConfig') or dict()
+ if self.autolock_managers is None:
+ self.autolock_managers = encryption_config.get('AutoLockManagers')
+
+ if self.name is None:
+ self.name = spec['Name']
+
+ if self.labels is None:
+ self.labels = spec.get('Labels') or {}
+
+ if 'LogDriver' in spec['TaskDefaults']:
+ self.log_driver = spec['TaskDefaults']['LogDriver']
+
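+    # Build the docker-py SwarmSpec from whichever of these options the
+    # installed Docker SDK for Python / Docker API version supports; options
+    # below their minimal version are simply skipped here.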
+ def update_parameters(self, client):
+ assign = dict(
+ snapshot_interval='snapshot_interval',
+ task_history_retention_limit='task_history_retention_limit',
+ keep_old_snapshots='keep_old_snapshots',
+ log_entries_for_slow_followers='log_entries_for_slow_followers',
+ heartbeat_tick='heartbeat_tick',
+ election_tick='election_tick',
+ dispatcher_heartbeat_period='dispatcher_heartbeat_period',
+ node_cert_expiry='node_cert_expiry',
+ name='name',
+ labels='labels',
+ signing_ca_cert='signing_ca_cert',
+ signing_ca_key='signing_ca_key',
+ ca_force_rotate='ca_force_rotate',
+ autolock_managers='autolock_managers',
+ log_driver='log_driver',
+ )
+ params = dict()
+ for dest, source in assign.items():
+ if not client.option_minimal_versions[source]['supported']:
+ continue
+ value = getattr(self, source)
+ if value is not None:
+ params[dest] = value
+ self.spec = client.create_swarm_spec(**params)
+
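+    # Field-by-field diff of the requested parameters against the active swarm
+    # spec; options that only apply at init/join time are skipped, unset (None)
+    # options are ignored, and the token-rotation flags are added explicitly at
+    # the end.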
+ def compare_to_active(self, other, client, differences):
+ for k in self.__dict__:
+ if k in ('advertise_addr', 'listen_addr', 'remote_addrs', 'join_token',
+ 'rotate_worker_token', 'rotate_manager_token', 'spec',
+ 'default_addr_pool', 'subnet_size'):
+ continue
+ if not client.option_minimal_versions[k]['supported']:
+ continue
+ value = getattr(self, k)
+ if value is None:
+ continue
+ other_value = getattr(other, k)
+ if value != other_value:
+ differences.add(k, parameter=value, active=other_value)
+ if self.rotate_worker_token:
+ differences.add('rotate_worker_token', parameter=True, active=False)
+ if self.rotate_manager_token:
+ differences.add('rotate_manager_token', parameter=True, active=False)
+ return differences
+
+
+class SwarmManager(DockerBaseClass):
+
+ def __init__(self, client, results):
+
+ super(SwarmManager, self).__init__()
+
+ self.client = client
+ self.results = results
+ self.check_mode = self.client.check_mode
+ self.swarm_info = {}
+
+ self.state = client.module.params['state']
+ self.force = client.module.params['force']
+ self.node_id = client.module.params['node_id']
+
+ self.differences = DifferenceTracker()
+ self.parameters = TaskParameters.from_ansible_params(client)
+
+ self.created = False
+
+ def __call__(self):
+ choice_map = {
+ "present": self.init_swarm,
+ "join": self.join,
+ "absent": self.leave,
+ "remove": self.remove,
+ "inspect": self.inspect_swarm
+ }
+
+ if self.state == 'inspect':
+ self.client.module.deprecate(
+ "The 'inspect' state is deprecated, please use 'docker_swarm_info' to inspect swarm cluster",
+ version='2.12', collection_name='ansible.builtin')
+
+ choice_map.get(self.state)()
+
+ if self.client.module._diff or self.parameters.debug:
+ diff = dict()
+ diff['before'], diff['after'] = self.differences.get_before_after()
+ self.results['diff'] = diff
+
+ def inspect_swarm(self):
+ try:
+ data = self.client.inspect_swarm()
+ json_str = json.dumps(data, ensure_ascii=False)
+ self.swarm_info = json.loads(json_str)
+
+ self.results['changed'] = False
+ self.results['swarm_facts'] = self.swarm_info
+
+ unlock_key = self.get_unlock_key()
+ self.swarm_info.update(unlock_key)
+ except APIError:
+ return
+
+ def get_unlock_key(self):
+ default = {'UnlockKey': None}
+ if not self.has_swarm_lock_changed():
+ return default
+ try:
+ return self.client.get_unlock_key() or default
+ except APIError:
+ return default
+
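+    # The unlock key is only retrieved when manager auto-locking is in play:
+    # either the swarm was just created with autolock_managers enabled, or the
+    # autolock_managers setting changed during this update.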
+ def has_swarm_lock_changed(self):
+ return self.parameters.autolock_managers and (
+ self.created or self.differences.has_difference_for('autolock_managers')
+ )
+
+ def init_swarm(self):
+ if not self.force and self.client.check_if_swarm_manager():
+ self.__update_swarm()
+ return
+
+ if not self.check_mode:
+ init_arguments = {
+ 'advertise_addr': self.parameters.advertise_addr,
+ 'listen_addr': self.parameters.listen_addr,
+ 'force_new_cluster': self.force,
+ 'swarm_spec': self.parameters.spec,
+ }
+ if self.parameters.default_addr_pool is not None:
+ init_arguments['default_addr_pool'] = self.parameters.default_addr_pool
+ if self.parameters.subnet_size is not None:
+ init_arguments['subnet_size'] = self.parameters.subnet_size
+ try:
+ self.client.init_swarm(**init_arguments)
+ except APIError as exc:
+ self.client.fail("Can not create a new Swarm Cluster: %s" % to_native(exc))
+
+ if not self.client.check_if_swarm_manager():
+ if not self.check_mode:
+ self.client.fail("Swarm not created or other error!")
+
+ self.created = True
+ self.inspect_swarm()
+ self.results['actions'].append("New Swarm cluster created: %s" % (self.swarm_info.get('ID')))
+ self.differences.add('state', parameter='present', active='absent')
+ self.results['changed'] = True
+ self.results['swarm_facts'] = {
+ 'JoinTokens': self.swarm_info.get('JoinTokens'),
+ 'UnlockKey': self.swarm_info.get('UnlockKey')
+ }
+
+ def __update_swarm(self):
+ try:
+ self.inspect_swarm()
+ version = self.swarm_info['Version']['Index']
+ self.parameters.update_from_swarm_info(self.swarm_info)
+ old_parameters = TaskParameters()
+ old_parameters.update_from_swarm_info(self.swarm_info)
+ self.parameters.compare_to_active(old_parameters, self.client, self.differences)
+ if self.differences.empty:
+ self.results['actions'].append("No modification")
+ self.results['changed'] = False
+ return
+ update_parameters = TaskParameters.from_ansible_params(self.client)
+ update_parameters.update_parameters(self.client)
+ if not self.check_mode:
+ self.client.update_swarm(
+ version=version, swarm_spec=update_parameters.spec,
+ rotate_worker_token=self.parameters.rotate_worker_token,
+ rotate_manager_token=self.parameters.rotate_manager_token)
+ except APIError as exc:
+ self.client.fail("Can not update a Swarm Cluster: %s" % to_native(exc))
+ return
+
+ self.inspect_swarm()
+ self.results['actions'].append("Swarm cluster updated")
+ self.results['changed'] = True
+
+ def join(self):
+ if self.client.check_if_swarm_node():
+ self.results['actions'].append("This node is already part of a swarm.")
+ return
+ if not self.check_mode:
+ try:
+ self.client.join_swarm(
+ remote_addrs=self.parameters.remote_addrs, join_token=self.parameters.join_token,
+ listen_addr=self.parameters.listen_addr, advertise_addr=self.parameters.advertise_addr)
+ except APIError as exc:
+ self.client.fail("Can not join the Swarm Cluster: %s" % to_native(exc))
+ self.results['actions'].append("New node is added to swarm cluster")
+ self.differences.add('joined', parameter=True, active=False)
+ self.results['changed'] = True
+
+ def leave(self):
+ if not self.client.check_if_swarm_node():
+ self.results['actions'].append("This node is not part of a swarm.")
+ return
+ if not self.check_mode:
+ try:
+ self.client.leave_swarm(force=self.force)
+ except APIError as exc:
+ self.client.fail("This node can not leave the Swarm Cluster: %s" % to_native(exc))
+ self.results['actions'].append("Node has left the swarm cluster")
+ self.differences.add('joined', parameter='absent', active='present')
+ self.results['changed'] = True
+
+ def remove(self):
+ if not self.client.check_if_swarm_manager():
+ self.client.fail("This node is not a manager.")
+
+ try:
+ status_down = self.client.check_if_swarm_node_is_down(node_id=self.node_id, repeat_check=5)
+ except APIError:
+ return
+
+ if not status_down:
+ self.client.fail("Can not remove the node. The status node is ready and not down.")
+
+ if not self.check_mode:
+ try:
+ self.client.remove_node(node_id=self.node_id, force=self.force)
+ except APIError as exc:
+ self.client.fail("Can not remove the node from the Swarm Cluster: %s" % to_native(exc))
+ self.results['actions'].append("Node is removed from swarm cluster.")
+ self.differences.add('joined', parameter=False, active=True)
+ self.results['changed'] = True
+
+
+def _detect_remove_operation(client):
+ return client.module.params['state'] == 'remove'
+
+
+def main():
+ argument_spec = dict(
+ advertise_addr=dict(type='str'),
+ state=dict(type='str', default='present', choices=['present', 'join', 'absent', 'remove', 'inspect']),
+ force=dict(type='bool', default=False),
+ listen_addr=dict(type='str', default='0.0.0.0:2377'),
+ remote_addrs=dict(type='list', elements='str'),
+ join_token=dict(type='str'),
+ snapshot_interval=dict(type='int'),
+ task_history_retention_limit=dict(type='int'),
+ keep_old_snapshots=dict(type='int'),
+ log_entries_for_slow_followers=dict(type='int'),
+ heartbeat_tick=dict(type='int'),
+ election_tick=dict(type='int'),
+ dispatcher_heartbeat_period=dict(type='int'),
+ node_cert_expiry=dict(type='int'),
+ name=dict(type='str'),
+ labels=dict(type='dict'),
+ signing_ca_cert=dict(type='str'),
+ signing_ca_key=dict(type='str'),
+ ca_force_rotate=dict(type='int'),
+ autolock_managers=dict(type='bool'),
+ node_id=dict(type='str'),
+ rotate_worker_token=dict(type='bool', default=False),
+ rotate_manager_token=dict(type='bool', default=False),
+ default_addr_pool=dict(type='list', elements='str'),
+ subnet_size=dict(type='int'),
+ )
+
+ required_if = [
+ ('state', 'join', ['advertise_addr', 'remote_addrs', 'join_token']),
+ ('state', 'remove', ['node_id'])
+ ]
+
+ option_minimal_versions = dict(
+ labels=dict(docker_py_version='2.6.0', docker_api_version='1.32'),
+ signing_ca_cert=dict(docker_py_version='2.6.0', docker_api_version='1.30'),
+ signing_ca_key=dict(docker_py_version='2.6.0', docker_api_version='1.30'),
+ ca_force_rotate=dict(docker_py_version='2.6.0', docker_api_version='1.30'),
+ autolock_managers=dict(docker_py_version='2.6.0'),
+ log_driver=dict(docker_py_version='2.6.0'),
+ remove_operation=dict(
+ docker_py_version='2.4.0',
+ detect_usage=_detect_remove_operation,
+ usage_msg='remove swarm nodes'
+ ),
+ default_addr_pool=dict(docker_py_version='4.0.0', docker_api_version='1.39'),
+ subnet_size=dict(docker_py_version='4.0.0', docker_api_version='1.39'),
+ )
+
+ client = AnsibleDockerSwarmClient(
+ argument_spec=argument_spec,
+ supports_check_mode=True,
+ required_if=required_if,
+ min_docker_version='1.10.0',
+ min_docker_api_version='1.25',
+ option_minimal_versions=option_minimal_versions,
+ )
+
+ try:
+ results = dict(
+ changed=False,
+ result='',
+ actions=[]
+ )
+
+ SwarmManager(client, results)()
+ client.module.exit_json(**results)
+ except DockerException as e:
+ client.fail('An unexpected docker error occurred: {0}'.format(e), exception=traceback.format_exc())
+ except RequestException as e:
+ client.fail('An unexpected requests error occurred when docker-py tried to talk to the docker daemon: {0}'.format(e), exception=traceback.format_exc())
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/ec2.py b/test/support/integration/plugins/modules/ec2.py
new file mode 100644
index 00000000..952aa5a1
--- /dev/null
+++ b/test/support/integration/plugins/modules/ec2.py
@@ -0,0 +1,1766 @@
+#!/usr/bin/python
+# This file is part of Ansible
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+
+DOCUMENTATION = '''
+---
+module: ec2
+short_description: create, terminate, start or stop an instance in ec2
+description:
+ - Creates or terminates ec2 instances.
+ - >
+ Note: This module uses the older boto Python module to interact with the EC2 API.
+ M(ec2) will still receive bug fixes, but no new features.
+ Consider using the M(ec2_instance) module instead.
+ If M(ec2_instance) does not support a feature you need that is available in M(ec2), please
+ file a feature request.
+version_added: "0.9"
+options:
+ key_name:
+ description:
+ - Key pair to use on the instance.
+ - The SSH key must already exist in AWS in order to use this argument.
+ - Keys can be created / deleted using the M(ec2_key) module.
+ aliases: ['keypair']
+ type: str
+ id:
+ version_added: "1.1"
+ description:
+ - Identifier for this instance or set of instances, so that the module will be idempotent with respect to EC2 instances.
+ - This identifier is valid for at least 24 hours after the termination of the instance, and should not be reused for another call later on.
+ - For details, see the description of client token at U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Run_Instance_Idempotency.html).
+ type: str
+ group:
+ description:
+ - Security group (or list of groups) to use with the instance.
+ aliases: [ 'groups' ]
+ type: list
+ elements: str
+ group_id:
+ version_added: "1.1"
+ description:
+ - Security group id (or list of ids) to use with the instance.
+ type: list
+ elements: str
+ zone:
+ version_added: "1.2"
+ description:
+ - AWS availability zone in which to launch the instance.
+ aliases: [ 'aws_zone', 'ec2_zone' ]
+ type: str
+ instance_type:
+ description:
+ - Instance type to use for the instance, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-types.html).
+ - Required when creating a new instance.
+ type: str
+ aliases: ['type']
+ tenancy:
+ version_added: "1.9"
+ description:
+ - An instance with a tenancy of C(dedicated) runs on single-tenant hardware and can only be launched into a VPC.
+ - Note that to use dedicated tenancy you MUST specify a I(vpc_subnet_id) as well.
+ - Dedicated tenancy is not available for EC2 "micro" instances.
+ default: default
+ choices: [ "default", "dedicated" ]
+ type: str
+ spot_price:
+ version_added: "1.5"
+ description:
+ - Maximum spot price to bid. If not set, a regular on-demand instance is requested.
+ - A spot request is made with this maximum bid. When it is filled, the instance is started.
+ type: str
+ spot_type:
+ version_added: "2.0"
+ description:
+ - The type of spot request.
+ - After being interrupted a C(persistent) spot instance will be started once there is capacity to fill the request again.
+ default: "one-time"
+ choices: [ "one-time", "persistent" ]
+ type: str
+ image:
+ description:
+ - I(ami) ID to use for the instance.
+ - Required when I(state=present).
+ type: str
+ kernel:
+ description:
+ - Kernel eki to use for the instance.
+ type: str
+ ramdisk:
+ description:
+ - Ramdisk eri to use for the instance.
+ type: str
+ wait:
+ description:
+ - Wait for the instance to reach its desired state before returning.
+      - Does not wait for SSH; see the 'wait_for_connection' example for details.
+ type: bool
+ default: false
+ wait_timeout:
+ description:
+ - How long before wait gives up, in seconds.
+ default: 300
+ type: int
+ spot_wait_timeout:
+ version_added: "1.5"
+ description:
+ - How long to wait for the spot instance request to be fulfilled. Affects 'Request valid until' for setting spot request lifespan.
+ default: 600
+ type: int
+ count:
+ description:
+ - Number of instances to launch.
+ default: 1
+ type: int
+ monitoring:
+ version_added: "1.1"
+ description:
+ - Enable detailed monitoring (CloudWatch) for instance.
+ type: bool
+ default: false
+ user_data:
+ version_added: "0.9"
+ description:
+ - Opaque blob of data which is made available to the EC2 instance.
+ type: str
+ instance_tags:
+ version_added: "1.0"
+ description:
+      - A hash/dictionary of tags to add to the new instance or for starting/stopping instances by tag; '{"key":"value"}' and '{"key":"value","key":"value"}'.
+ type: dict
+ placement_group:
+ version_added: "1.3"
+ description:
+ - Placement group for the instance when using EC2 Clustered Compute.
+ type: str
+ vpc_subnet_id:
+ version_added: "1.1"
+ description:
+      - The subnet ID in which to launch the instance (VPC).
+ type: str
+ assign_public_ip:
+ version_added: "1.5"
+ description:
+ - When provisioning within vpc, assign a public IP address. Boto library must be 2.13.0+.
+ type: bool
+ private_ip:
+ version_added: "1.2"
+ description:
+ - The private ip address to assign the instance (from the vpc subnet).
+ type: str
+ instance_profile_name:
+ version_added: "1.3"
+ description:
+ - Name of the IAM instance profile (i.e. what the EC2 console refers to as an "IAM Role") to use. Boto library must be 2.5.0+.
+ type: str
+ instance_ids:
+ version_added: "1.3"
+ description:
+ - "list of instance ids, currently used for states: absent, running, stopped"
+ aliases: ['instance_id']
+ type: list
+ elements: str
+ source_dest_check:
+ version_added: "1.6"
+ description:
+ - Enable or Disable the Source/Destination checks (for NAT instances and Virtual Routers).
+ When initially creating an instance the EC2 API defaults this to C(True).
+ type: bool
+ termination_protection:
+ version_added: "2.0"
+ description:
+ - Enable or Disable the Termination Protection.
+ type: bool
+ default: false
+ instance_initiated_shutdown_behavior:
+ version_added: "2.2"
+ description:
+      - Set whether AWS will Stop or Terminate an instance on shutdown. This parameter is ignored when using instance-store
+        images (which require termination on shutdown).
+ default: 'stop'
+ choices: [ "stop", "terminate" ]
+ type: str
+ state:
+ version_added: "1.3"
+ description:
+ - Create, terminate, start, stop or restart instances. The state 'restarted' was added in Ansible 2.2.
+ - When I(state=absent), I(instance_ids) is required.
+ - When I(state=running), I(state=stopped) or I(state=restarted) then either I(instance_ids) or I(instance_tags) is required.
+ default: 'present'
+ choices: ['absent', 'present', 'restarted', 'running', 'stopped']
+ type: str
+ volumes:
+ version_added: "1.5"
+ description:
+ - A list of hash/dictionaries of volumes to add to the new instance.
+ type: list
+ elements: dict
+ suboptions:
+ device_name:
+ type: str
+ required: true
+ description:
+ - A name for the device (For example C(/dev/sda)).
+ delete_on_termination:
+ type: bool
+ default: false
+ description:
+ - Whether the volume should be automatically deleted when the instance is terminated.
+ ephemeral:
+ type: str
+ description:
+ - Whether the volume should be ephemeral.
+ - Data on ephemeral volumes is lost when the instance is stopped.
+ - Mutually exclusive with the I(snapshot) parameter.
+ encrypted:
+ type: bool
+ default: false
+ description:
+ - Whether the volume should be encrypted using the 'aws/ebs' KMS CMK.
+ snapshot:
+ type: str
+ description:
+ - The ID of an EBS snapshot to copy when creating the volume.
+ - Mutually exclusive with the I(ephemeral) parameter.
+ volume_type:
+ type: str
+ description:
+ - The type of volume to create.
+ - See U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html) for more information on the available volume types.
+ volume_size:
+ type: int
+ description:
+ - The size of the volume (in GiB).
+ iops:
+ type: int
+ description:
+ - The number of IOPS per second to provision for the volume.
+ - Required when I(volume_type=io1).
+ ebs_optimized:
+ version_added: "1.6"
+ description:
+ - Whether instance is using optimized EBS volumes, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSOptimized.html).
+ default: false
+ type: bool
+ exact_count:
+ version_added: "1.5"
+ description:
+      - An integer value which indicates how many instances matching the I(count_tag) parameter should be running.
+ Instances are either created or terminated based on this value.
+ type: int
+ count_tag:
+ version_added: "1.5"
+ description:
+ - Used with I(exact_count) to determine how many nodes based on a specific tag criteria should be running.
+ This can be expressed in multiple ways and is shown in the EXAMPLES section. For instance, one can request 25 servers
+ that are tagged with "class=webserver". The specified tag must already exist or be passed in as the I(instance_tags) option.
+ type: raw
+ network_interfaces:
+ version_added: "2.0"
+ description:
+ - A list of existing network interfaces to attach to the instance at launch. When specifying existing network interfaces,
+ none of the I(assign_public_ip), I(private_ip), I(vpc_subnet_id), I(group), or I(group_id) parameters may be used. (Those parameters are
+ for creating a new network interface at launch.)
+ aliases: ['network_interface']
+ type: list
+ elements: str
+ spot_launch_group:
+ version_added: "2.1"
+ description:
+ - Launch group for spot requests, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html#spot-launch-group).
+ type: str
+author:
+ - "Tim Gerla (@tgerla)"
+ - "Lester Wade (@lwade)"
+ - "Seth Vidal (@skvidal)"
+extends_documentation_fragment:
+ - aws
+ - ec2
+'''
+
+EXAMPLES = '''
+# Note: These examples do not set authentication details, see the AWS Guide for details.
+
+# Basic provisioning example
+- ec2:
+ key_name: mykey
+ instance_type: t2.micro
+ image: ami-123456
+ wait: yes
+ group: webserver
+ count: 3
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+# Advanced example with tagging and CloudWatch
+- ec2:
+ key_name: mykey
+ group: databases
+ instance_type: t2.micro
+ image: ami-123456
+ wait: yes
+ wait_timeout: 500
+ count: 5
+ instance_tags:
+ db: postgres
+ monitoring: yes
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+# Single instance with additional IOPS volume from snapshot and volume delete on termination
+- ec2:
+ key_name: mykey
+ group: webserver
+ instance_type: c3.medium
+ image: ami-123456
+ wait: yes
+ wait_timeout: 500
+ volumes:
+ - device_name: /dev/sdb
+ snapshot: snap-abcdef12
+ volume_type: io1
+ iops: 1000
+ volume_size: 100
+ delete_on_termination: true
+ monitoring: yes
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+# Single instance with ssd gp2 root volume
+- ec2:
+ key_name: mykey
+ group: webserver
+ instance_type: c3.medium
+ image: ami-123456
+ wait: yes
+ wait_timeout: 500
+ volumes:
+ - device_name: /dev/xvda
+ volume_type: gp2
+ volume_size: 8
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+ count_tag:
+ Name: dbserver
+ exact_count: 1
+
+# Multiple groups example
+- ec2:
+ key_name: mykey
+ group: ['databases', 'internal-services', 'sshable', 'and-so-forth']
+ instance_type: m1.large
+ image: ami-6e649707
+ wait: yes
+ wait_timeout: 500
+ count: 5
+ instance_tags:
+ db: postgres
+ monitoring: yes
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+# Multiple instances with additional volume from snapshot
+- ec2:
+ key_name: mykey
+ group: webserver
+ instance_type: m1.large
+ image: ami-6e649707
+ wait: yes
+ wait_timeout: 500
+ count: 5
+ volumes:
+ - device_name: /dev/sdb
+ snapshot: snap-abcdef12
+ volume_size: 10
+ monitoring: yes
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+# Dedicated tenancy example
+- local_action:
+ module: ec2
+ assign_public_ip: yes
+ group_id: sg-1dc53f72
+ key_name: mykey
+ image: ami-6e649707
+ instance_type: m1.small
+ tenancy: dedicated
+ vpc_subnet_id: subnet-29e63245
+ wait: yes
+
+# Spot instance example
+- ec2:
+ spot_price: 0.24
+ spot_wait_timeout: 600
+ keypair: mykey
+ group_id: sg-1dc53f72
+ instance_type: m1.small
+ image: ami-6e649707
+ wait: yes
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+ spot_launch_group: report_generators
+ instance_initiated_shutdown_behavior: terminate
+
+# Examples using pre-existing network interfaces
+- ec2:
+ key_name: mykey
+ instance_type: t2.small
+ image: ami-f005ba11
+ network_interface: eni-deadbeef
+
+- ec2:
+ key_name: mykey
+ instance_type: t2.small
+ image: ami-f005ba11
+ network_interfaces: ['eni-deadbeef', 'eni-5ca1ab1e']
+
+# Launch instances, runs some tasks
+# and then terminate them
+
+- name: Create a sandbox instance
+ hosts: localhost
+ gather_facts: False
+ vars:
+ keypair: my_keypair
+ instance_type: m1.small
+ security_group: my_securitygroup
+ image: my_ami_id
+ region: us-east-1
+ tasks:
+ - name: Launch instance
+ ec2:
+ key_name: "{{ keypair }}"
+ group: "{{ security_group }}"
+ instance_type: "{{ instance_type }}"
+ image: "{{ image }}"
+ wait: true
+ region: "{{ region }}"
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+ register: ec2
+
+ - name: Add new instance to host group
+ add_host:
+ hostname: "{{ item.public_ip }}"
+ groupname: launched
+ loop: "{{ ec2.instances }}"
+
+ - name: Wait for SSH to come up
+ delegate_to: "{{ item.public_dns_name }}"
+ wait_for_connection:
+ delay: 60
+ timeout: 320
+ loop: "{{ ec2.instances }}"
+
+- name: Configure instance(s)
+ hosts: launched
+ become: True
+ gather_facts: True
+ roles:
+ - my_awesome_role
+ - my_awesome_test
+
+- name: Terminate instances
+ hosts: localhost
+ tasks:
+ - name: Terminate instances that were previously launched
+ ec2:
+ state: 'absent'
+ instance_ids: '{{ ec2.instance_ids }}'
+
+# Start a few existing instances, run some tasks
+# and stop the instances
+
+- name: Start sandbox instances
+ hosts: localhost
+ gather_facts: false
+ vars:
+ instance_ids:
+ - 'i-xxxxxx'
+ - 'i-xxxxxx'
+ - 'i-xxxxxx'
+ region: us-east-1
+ tasks:
+ - name: Start the sandbox instances
+ ec2:
+ instance_ids: '{{ instance_ids }}'
+ region: '{{ region }}'
+ state: running
+ wait: True
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+ roles:
+ - do_neat_stuff
+ - do_more_neat_stuff
+
+- name: Stop sandbox instances
+ hosts: localhost
+ gather_facts: false
+ vars:
+ instance_ids:
+ - 'i-xxxxxx'
+ - 'i-xxxxxx'
+ - 'i-xxxxxx'
+ region: us-east-1
+ tasks:
+ - name: Stop the sandbox instances
+ ec2:
+ instance_ids: '{{ instance_ids }}'
+ region: '{{ region }}'
+ state: stopped
+ wait: True
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+#
+# Start stopped instances specified by tag
+#
+- local_action:
+ module: ec2
+ instance_tags:
+ Name: ExtraPower
+ state: running
+
+#
+# Restart instances specified by tag
+#
+- local_action:
+ module: ec2
+ instance_tags:
+ Name: ExtraPower
+ state: restarted
+
+#
+# Enforce that 5 instances with a tag "foo" are running
+# (Highly recommended!)
+#
+
+- ec2:
+ key_name: mykey
+ instance_type: c1.medium
+ image: ami-40603AD1
+ wait: yes
+ group: webserver
+ instance_tags:
+ foo: bar
+ exact_count: 5
+ count_tag: foo
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+#
+# Enforce that 5 running instances named "database" with a "dbtype" of "postgres"
+#
+
+- ec2:
+ key_name: mykey
+ instance_type: c1.medium
+ image: ami-40603AD1
+ wait: yes
+ group: webserver
+ instance_tags:
+ Name: database
+ dbtype: postgres
+ exact_count: 5
+ count_tag:
+ Name: database
+ dbtype: postgres
+ vpc_subnet_id: subnet-29e63245
+ assign_public_ip: yes
+
+#
+# count_tag complex argument examples
+#
+
+ # instances with tag foo
+- ec2:
+ count_tag:
+ foo:
+
+ # instances with tag foo=bar
+- ec2:
+ count_tag:
+ foo: bar
+
+ # instances with tags foo=bar & baz
+- ec2:
+ count_tag:
+ foo: bar
+ baz:
+
+ # instances with tags foo & bar & baz=bang
+- ec2:
+ count_tag:
+ - foo
+ - bar
+ - baz: bang
+
+'''
+
+import time
+import datetime
+import traceback
+from ast import literal_eval
+from distutils.version import LooseVersion
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.ec2 import get_aws_connection_info, ec2_argument_spec, ec2_connect
+from ansible.module_utils.six import get_function_code, string_types
+from ansible.module_utils._text import to_bytes, to_text
+
+try:
+ import boto.ec2
+ from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping
+ from boto.exception import EC2ResponseError
+ from boto import connect_ec2_endpoint
+ from boto import connect_vpc
+ HAS_BOTO = True
+except ImportError:
+ HAS_BOTO = False
+
+
+def find_running_instances_by_count_tag(module, ec2, vpc, count_tag, zone=None):
+
+ # get reservations for instances that match tag(s) and are in the desired state
+ state = module.params.get('state')
+ if state not in ['running', 'stopped']:
+ state = None
+ reservations = get_reservations(module, ec2, vpc, tags=count_tag, state=state, zone=zone)
+
+ instances = []
+ for res in reservations:
+ if hasattr(res, 'instances'):
+ for inst in res.instances:
+ if inst.state == 'terminated' or inst.state == 'shutting-down':
+ continue
+ instances.append(inst)
+
+ return reservations, instances
+
+
+def _set_none_to_blank(dictionary):
+ result = dictionary
+ for k in result:
+ if isinstance(result[k], dict):
+ result[k] = _set_none_to_blank(result[k])
+ elif not result[k]:
+ result[k] = ""
+ return result
+
+
+def get_reservations(module, ec2, vpc, tags=None, state=None, zone=None):
+ # TODO: filters do not work with tags that have underscores
+ filters = dict()
+
+ vpc_subnet_id = module.params.get('vpc_subnet_id')
+ vpc_id = None
+ if vpc_subnet_id:
+ filters.update({"subnet-id": vpc_subnet_id})
+ if vpc:
+ vpc_id = vpc.get_all_subnets(subnet_ids=[vpc_subnet_id])[0].vpc_id
+
+ if vpc_id:
+ filters.update({"vpc-id": vpc_id})
+
+ if tags is not None:
+
+ if isinstance(tags, str):
+ try:
+ tags = literal_eval(tags)
+ except Exception:
+ pass
+
+ # if not a string type, convert and make sure it's a text string
+ if isinstance(tags, int):
+ tags = to_text(tags)
+
+ # if string, we only care that a tag of that name exists
+ if isinstance(tags, str):
+ filters.update({"tag-key": tags})
+
+ # if list, append each item to filters
+ if isinstance(tags, list):
+ for x in tags:
+ if isinstance(x, dict):
+ x = _set_none_to_blank(x)
+ filters.update(dict(("tag:" + tn, tv) for (tn, tv) in x.items()))
+ else:
+ filters.update({"tag-key": x})
+
+ # if dict, add the key and value to the filter
+ if isinstance(tags, dict):
+ tags = _set_none_to_blank(tags)
+ filters.update(dict(("tag:" + tn, tv) for (tn, tv) in tags.items()))
+
+ # lets check to see if the filters dict is empty, if so then stop
+ if not filters:
+ module.fail_json(msg="Filters based on tag is empty => tags: %s" % (tags))
+
+ if state:
+ # http://stackoverflow.com/questions/437511/what-are-the-valid-instancestates-for-the-amazon-ec2-api
+ filters.update({'instance-state-name': state})
+
+ if zone:
+ filters.update({'availability-zone': zone})
+
+ if module.params.get('id'):
+ filters['client-token'] = module.params['id']
+
+ results = ec2.get_all_instances(filters=filters)
+
+ return results
+
+
+def get_instance_info(inst):
+ """
+    Retrieves instance information from a boto instance
+    object and returns it as a dictionary
+ """
+ instance_info = {'id': inst.id,
+ 'ami_launch_index': inst.ami_launch_index,
+ 'private_ip': inst.private_ip_address,
+ 'private_dns_name': inst.private_dns_name,
+ 'public_ip': inst.ip_address,
+ 'dns_name': inst.dns_name,
+ 'public_dns_name': inst.public_dns_name,
+ 'state_code': inst.state_code,
+ 'architecture': inst.architecture,
+ 'image_id': inst.image_id,
+ 'key_name': inst.key_name,
+ 'placement': inst.placement,
+ 'region': inst.placement[:-1],
+ 'kernel': inst.kernel,
+ 'ramdisk': inst.ramdisk,
+ 'launch_time': inst.launch_time,
+ 'instance_type': inst.instance_type,
+ 'root_device_type': inst.root_device_type,
+ 'root_device_name': inst.root_device_name,
+ 'state': inst.state,
+ 'hypervisor': inst.hypervisor,
+ 'tags': inst.tags,
+ 'groups': dict((group.id, group.name) for group in inst.groups),
+ }
+ try:
+ instance_info['virtualization_type'] = getattr(inst, 'virtualization_type')
+ except AttributeError:
+ instance_info['virtualization_type'] = None
+
+ try:
+ instance_info['ebs_optimized'] = getattr(inst, 'ebs_optimized')
+ except AttributeError:
+ instance_info['ebs_optimized'] = False
+
+ try:
+ bdm_dict = {}
+ bdm = getattr(inst, 'block_device_mapping')
+ for device_name in bdm.keys():
+ bdm_dict[device_name] = {
+ 'status': bdm[device_name].status,
+ 'volume_id': bdm[device_name].volume_id,
+ 'delete_on_termination': bdm[device_name].delete_on_termination
+ }
+ instance_info['block_device_mapping'] = bdm_dict
+ except AttributeError:
+ instance_info['block_device_mapping'] = False
+
+ try:
+ instance_info['tenancy'] = getattr(inst, 'placement_tenancy')
+ except AttributeError:
+ instance_info['tenancy'] = 'default'
+
+ return instance_info
+
+
+def boto_supports_associate_public_ip_address(ec2):
+ """
+ Check if Boto library has associate_public_ip_address in the NetworkInterfaceSpecification
+ class. Added in Boto 2.13.0
+
+ ec2: authenticated ec2 connection object
+
+ Returns:
+        True if the Boto library accepts the associate_public_ip_address argument, else False
+ """
+
+ try:
+ network_interface = boto.ec2.networkinterface.NetworkInterfaceSpecification()
+ getattr(network_interface, "associate_public_ip_address")
+ return True
+ except AttributeError:
+ return False
+
+
+def boto_supports_profile_name_arg(ec2):
+ """
+    Check if the Boto library supports the instance_profile_name argument, which was added in Boto 2.5.0
+
+ ec2: authenticated ec2 connection object
+
+ Returns:
+        True if the Boto library accepts the instance_profile_name argument, else False
+ """
+ run_instances_method = getattr(ec2, 'run_instances')
+ return 'instance_profile_name' in get_function_code(run_instances_method).co_varnames
+
+
+def boto_supports_volume_encryption():
+ """
+ Check if Boto library supports encryption of EBS volumes (added in 2.29.0)
+
+ Returns:
+        True if the installed boto version is at least 2.29.0 (the release that added volume encryption), else False
+ """
+ return hasattr(boto, 'Version') and LooseVersion(boto.Version) >= LooseVersion('2.29.0')
+
+
+def create_block_device(module, ec2, volume):
+    # Not aware of a way to determine this programmatically
+ # http://aws.amazon.com/about-aws/whats-new/2013/10/09/ebs-provisioned-iops-maximum-iops-gb-ratio-increased-to-30-1/
+ MAX_IOPS_TO_SIZE_RATIO = 30
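+    # Worked example: under the 30:1 cap, a 100 GiB io1 volume may be
+    # provisioned with at most 30 * 100 = 3000 IOPS.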
+
+ volume_type = volume.get('volume_type')
+
+ if 'snapshot' not in volume and 'ephemeral' not in volume:
+ if 'volume_size' not in volume:
+ module.fail_json(msg='Size must be specified when creating a new volume or modifying the root volume')
+ if 'snapshot' in volume:
+ if volume_type == 'io1' and 'iops' not in volume:
+ module.fail_json(msg='io1 volumes must have an iops value set')
+ if 'iops' in volume:
+ snapshot = ec2.get_all_snapshots(snapshot_ids=[volume['snapshot']])[0]
+ size = volume.get('volume_size', snapshot.volume_size)
+ if int(volume['iops']) > MAX_IOPS_TO_SIZE_RATIO * size:
+ module.fail_json(msg='IOPS must be at most %d times greater than size' % MAX_IOPS_TO_SIZE_RATIO)
+ if 'ephemeral' in volume:
+ if 'snapshot' in volume:
+ module.fail_json(msg='Cannot set both ephemeral and snapshot')
+ if boto_supports_volume_encryption():
+ return BlockDeviceType(snapshot_id=volume.get('snapshot'),
+ ephemeral_name=volume.get('ephemeral'),
+ size=volume.get('volume_size'),
+ volume_type=volume_type,
+ delete_on_termination=volume.get('delete_on_termination', False),
+ iops=volume.get('iops'),
+ encrypted=volume.get('encrypted', None))
+ else:
+ return BlockDeviceType(snapshot_id=volume.get('snapshot'),
+ ephemeral_name=volume.get('ephemeral'),
+ size=volume.get('volume_size'),
+ volume_type=volume_type,
+ delete_on_termination=volume.get('delete_on_termination', False),
+ iops=volume.get('iops'))
+
+
+def boto_supports_param_in_spot_request(ec2, param):
+ """
+ Check if Boto library has a <param> in its request_spot_instances() method. For example, the placement_group parameter wasn't added until 2.3.0.
+
+ ec2: authenticated ec2 connection object
+
+ Returns:
+ True if boto library has the named param as an argument on the request_spot_instances method, else False
+ """
+ method = getattr(ec2, 'request_spot_instances')
+ return param in get_function_code(method).co_varnames
+
+
+def await_spot_requests(module, ec2, spot_requests, count):
+ """
+ Wait for a group of spot requests to be fulfilled, or fail.
+
+ module: Ansible module object
+ ec2: authenticated ec2 connection object
+    spot_requests: list of boto.ec2.spotinstancerequest.SpotInstanceRequest objects returned by ec2.request_spot_instances
+ count: Total number of instances to be created by the spot requests
+
+ Returns:
+        list of instance IDs created by the spot request(s)
+ """
+ spot_wait_timeout = int(module.params.get('spot_wait_timeout'))
+ wait_complete = time.time() + spot_wait_timeout
+
+ spot_req_inst_ids = dict()
+ while time.time() < wait_complete:
+ reqs = ec2.get_all_spot_instance_requests()
+ for sirb in spot_requests:
+ if sirb.id in spot_req_inst_ids:
+ continue
+ for sir in reqs:
+ if sir.id != sirb.id:
+ continue # this is not our spot instance
+ if sir.instance_id is not None:
+ spot_req_inst_ids[sirb.id] = sir.instance_id
+ elif sir.state == 'open':
+ continue # still waiting, nothing to do here
+ elif sir.state == 'active':
+ continue # Instance is created already, nothing to do here
+ elif sir.state == 'failed':
+ module.fail_json(msg="Spot instance request %s failed with status %s and fault %s:%s" % (
+ sir.id, sir.status.code, sir.fault.code, sir.fault.message))
+ elif sir.state == 'cancelled':
+ module.fail_json(msg="Spot instance request %s was cancelled before it could be fulfilled." % sir.id)
+ elif sir.state == 'closed':
+ # instance is terminating or marked for termination
+ # this may be intentional on the part of the operator,
+ # or it may have been terminated by AWS due to capacity,
+                        # price, or group constraints. In this case, we'll fail
+ # the module if the reason for the state is anything
+ # other than termination by user. Codes are documented at
+ # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-bid-status.html
+ if sir.status.code == 'instance-terminated-by-user':
+ # do nothing, since the user likely did this on purpose
+ pass
+ else:
+ spot_msg = "Spot instance request %s was closed by AWS with the status %s and fault %s:%s"
+ module.fail_json(msg=spot_msg % (sir.id, sir.status.code, sir.fault.code, sir.fault.message))
+
+ if len(spot_req_inst_ids) < count:
+ time.sleep(5)
+ else:
+ return list(spot_req_inst_ids.values())
+ module.fail_json(msg="wait for spot requests timeout on %s" % time.asctime())
+
+
+def enforce_count(module, ec2, vpc):
+
+ exact_count = module.params.get('exact_count')
+ count_tag = module.params.get('count_tag')
+ zone = module.params.get('zone')
+
+ # fail here if the exact count was specified without filtering
+    # on a tag, as this may lead to an undesired removal of instances
+ if exact_count and count_tag is None:
+ module.fail_json(msg="you must use the 'count_tag' option with exact_count")
+
+ reservations, instances = find_running_instances_by_count_tag(module, ec2, vpc, count_tag, zone)
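+    # Example: exact_count=5 with 3 matching instances running creates 2 more;
+    # with 7 matching instances running, the 2 whose IDs sort lowest are
+    # terminated (see the sorted() call below).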
+
+ changed = None
+ checkmode = False
+ instance_dict_array = []
+ changed_instance_ids = None
+
+ if len(instances) == exact_count:
+ changed = False
+ elif len(instances) < exact_count:
+ changed = True
+ to_create = exact_count - len(instances)
+ if not checkmode:
+ (instance_dict_array, changed_instance_ids, changed) \
+ = create_instances(module, ec2, vpc, override_count=to_create)
+
+ for inst in instance_dict_array:
+ instances.append(inst)
+ elif len(instances) > exact_count:
+ changed = True
+ to_remove = len(instances) - exact_count
+ if not checkmode:
+ all_instance_ids = sorted([x.id for x in instances])
+ remove_ids = all_instance_ids[0:to_remove]
+
+ instances = [x for x in instances if x.id not in remove_ids]
+
+ (changed, instance_dict_array, changed_instance_ids) \
+ = terminate_instances(module, ec2, remove_ids)
+ terminated_list = []
+ for inst in instance_dict_array:
+ inst['state'] = "terminated"
+ terminated_list.append(inst)
+ instance_dict_array = terminated_list
+
+ # ensure all instances are dictionaries
+ all_instances = []
+ for inst in instances:
+
+ if not isinstance(inst, dict):
+ warn_if_public_ip_assignment_changed(module, inst)
+ inst = get_instance_info(inst)
+ all_instances.append(inst)
+
+ return (all_instances, instance_dict_array, changed_instance_ids, changed)
+
+
+def create_instances(module, ec2, vpc, override_count=None):
+ """
+ Creates new instances
+
+ module : AnsibleModule object
+ ec2: authenticated ec2 connection object
+
+ Returns:
+ A list of dictionaries with instance information
+ about the instances that were launched
+ """
+
+ key_name = module.params.get('key_name')
+ id = module.params.get('id')
+ group_name = module.params.get('group')
+ group_id = module.params.get('group_id')
+ zone = module.params.get('zone')
+ instance_type = module.params.get('instance_type')
+ tenancy = module.params.get('tenancy')
+ spot_price = module.params.get('spot_price')
+ spot_type = module.params.get('spot_type')
+ image = module.params.get('image')
+ if override_count:
+ count = override_count
+ else:
+ count = module.params.get('count')
+ monitoring = module.params.get('monitoring')
+ kernel = module.params.get('kernel')
+ ramdisk = module.params.get('ramdisk')
+ wait = module.params.get('wait')
+ wait_timeout = int(module.params.get('wait_timeout'))
+ spot_wait_timeout = int(module.params.get('spot_wait_timeout'))
+ placement_group = module.params.get('placement_group')
+ user_data = module.params.get('user_data')
+ instance_tags = module.params.get('instance_tags')
+ vpc_subnet_id = module.params.get('vpc_subnet_id')
+ assign_public_ip = module.boolean(module.params.get('assign_public_ip'))
+ private_ip = module.params.get('private_ip')
+ instance_profile_name = module.params.get('instance_profile_name')
+ volumes = module.params.get('volumes')
+ ebs_optimized = module.params.get('ebs_optimized')
+ exact_count = module.params.get('exact_count')
+ count_tag = module.params.get('count_tag')
+ source_dest_check = module.boolean(module.params.get('source_dest_check'))
+ termination_protection = module.boolean(module.params.get('termination_protection'))
+ network_interfaces = module.params.get('network_interfaces')
+ spot_launch_group = module.params.get('spot_launch_group')
+ instance_initiated_shutdown_behavior = module.params.get('instance_initiated_shutdown_behavior')
+
+ vpc_id = None
+ if vpc_subnet_id:
+ if not vpc:
+ module.fail_json(msg="region must be specified")
+ else:
+ vpc_id = vpc.get_all_subnets(subnet_ids=[vpc_subnet_id])[0].vpc_id
+ else:
+ vpc_id = None
+
+ try:
+ # Here we try to lookup the group id from the security group name - if group is set.
+ if group_name:
+ if vpc_id:
+ grp_details = ec2.get_all_security_groups(filters={'vpc_id': vpc_id})
+ else:
+ grp_details = ec2.get_all_security_groups()
+ if isinstance(group_name, string_types):
+ group_name = [group_name]
+ unmatched = set(group_name).difference(str(grp.name) for grp in grp_details)
+ if len(unmatched) > 0:
+ module.fail_json(msg="The following group names are not valid: %s" % ', '.join(unmatched))
+ group_id = [str(grp.id) for grp in grp_details if str(grp.name) in group_name]
+ # Now we try to lookup the group id testing if group exists.
+ elif group_id:
+ # wrap the group_id in a list if it's not one already
+ if isinstance(group_id, string_types):
+ group_id = [group_id]
+ grp_details = ec2.get_all_security_groups(group_ids=group_id)
+ group_name = [grp_item.name for grp_item in grp_details]
+ except boto.exception.NoAuthHandlerFound as e:
+ module.fail_json(msg=str(e))
+
+    # Look up any instances that match our run id.
+
+ running_instances = []
+ count_remaining = int(count)
+
+ if id is not None:
+ filter_dict = {'client-token': id, 'instance-state-name': 'running'}
+ previous_reservations = ec2.get_all_instances(None, filter_dict)
+ for res in previous_reservations:
+ for prev_instance in res.instances:
+ running_instances.append(prev_instance)
+ count_remaining = count_remaining - len(running_instances)
+
+ # Both min_count and max_count equal count parameter. This means the launch request is explicit (we want count, or fail) in how many instances we want.
+
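+    # Example: count=5 with 2 instances already running under this client
+    # token leaves count_remaining=3, so exactly 3 more are requested.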
+ if count_remaining == 0:
+ changed = False
+ else:
+ changed = True
+ try:
+ params = {'image_id': image,
+ 'key_name': key_name,
+ 'monitoring_enabled': monitoring,
+ 'placement': zone,
+ 'instance_type': instance_type,
+ 'kernel_id': kernel,
+ 'ramdisk_id': ramdisk}
+ if user_data is not None:
+ params['user_data'] = to_bytes(user_data, errors='surrogate_or_strict')
+
+ if ebs_optimized:
+ params['ebs_optimized'] = ebs_optimized
+
+ # 'tenancy' always has a default value, but it is not a valid parameter for spot instance request
+ if not spot_price:
+ params['tenancy'] = tenancy
+
+ if boto_supports_profile_name_arg(ec2):
+ params['instance_profile_name'] = instance_profile_name
+ else:
+ if instance_profile_name is not None:
+ module.fail_json(
+ msg="instance_profile_name parameter requires Boto version 2.5.0 or higher")
+
+ if assign_public_ip is not None:
+ if not boto_supports_associate_public_ip_address(ec2):
+ module.fail_json(
+ msg="assign_public_ip parameter requires Boto version 2.13.0 or higher.")
+ elif not vpc_subnet_id:
+ module.fail_json(
+ msg="assign_public_ip only available with vpc_subnet_id")
+
+ else:
+ if private_ip:
+ interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(
+ subnet_id=vpc_subnet_id,
+ private_ip_address=private_ip,
+ groups=group_id,
+ associate_public_ip_address=assign_public_ip)
+ else:
+ interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(
+ subnet_id=vpc_subnet_id,
+ groups=group_id,
+ associate_public_ip_address=assign_public_ip)
+ interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface)
+ params['network_interfaces'] = interfaces
+ else:
+ if network_interfaces:
+ if isinstance(network_interfaces, string_types):
+ network_interfaces = [network_interfaces]
+ interfaces = []
+ for i, network_interface_id in enumerate(network_interfaces):
+ interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(
+ network_interface_id=network_interface_id,
+ device_index=i)
+ interfaces.append(interface)
+ params['network_interfaces'] = \
+ boto.ec2.networkinterface.NetworkInterfaceCollection(*interfaces)
+ else:
+ params['subnet_id'] = vpc_subnet_id
+ if vpc_subnet_id:
+ params['security_group_ids'] = group_id
+ else:
+ params['security_groups'] = group_name
+
+ if volumes:
+ bdm = BlockDeviceMapping()
+ for volume in volumes:
+ if 'device_name' not in volume:
+ module.fail_json(msg='Device name must be set for volume')
+ # Minimum volume size is 1GiB. We'll use volume size explicitly set to 0
+ # to be a signal not to create this volume
+ if 'volume_size' not in volume or int(volume['volume_size']) > 0:
+ bdm[volume['device_name']] = create_block_device(module, ec2, volume)
+
+ params['block_device_map'] = bdm
+
+ # check to see if we're using spot pricing first before starting instances
+ if not spot_price:
+ if assign_public_ip is not None and private_ip:
+ params.update(
+ dict(
+ min_count=count_remaining,
+ max_count=count_remaining,
+ client_token=id,
+ placement_group=placement_group,
+ )
+ )
+ else:
+ params.update(
+ dict(
+ min_count=count_remaining,
+ max_count=count_remaining,
+ client_token=id,
+ placement_group=placement_group,
+ private_ip_address=private_ip,
+ )
+ )
+
+ # For ordinary (not spot) instances, we can select 'stop'
+ # (the default) or 'terminate' here.
+ params['instance_initiated_shutdown_behavior'] = instance_initiated_shutdown_behavior or 'stop'
+
+ try:
+ res = ec2.run_instances(**params)
+ except boto.exception.EC2ResponseError as e:
+ if (params['instance_initiated_shutdown_behavior'] != 'terminate' and
+ "InvalidParameterCombination" == e.error_code):
+ params['instance_initiated_shutdown_behavior'] = 'terminate'
+ res = ec2.run_instances(**params)
+ else:
+ raise
+
+ instids = [i.id for i in res.instances]
+ while True:
+ try:
+ ec2.get_all_instances(instids)
+ break
+ except boto.exception.EC2ResponseError as e:
+ if "<Code>InvalidInstanceID.NotFound</Code>" in str(e):
+ # there's a race between start and get an instance
+ continue
+ else:
+ module.fail_json(msg=str(e))
+
+ # The instances returned through ec2.run_instances above can be in
+ # terminated state due to idempotency. See commit 7f11c3d for a complete
+ # explanation.
+ terminated_instances = [
+ str(instance.id) for instance in res.instances if instance.state == 'terminated'
+ ]
+ if terminated_instances:
+ module.fail_json(msg="Instances with id(s) %s " % terminated_instances +
+ "were created previously but have since been terminated - " +
+ "use a (possibly different) 'instanceid' parameter")
+
+ else:
+ if private_ip:
+ module.fail_json(
+ msg='private_ip only available with on-demand (non-spot) instances')
+ if boto_supports_param_in_spot_request(ec2, 'placement_group'):
+ params['placement_group'] = placement_group
+ elif placement_group:
+ module.fail_json(
+ msg="placement_group parameter requires Boto version 2.3.0 or higher.")
+
+ # You can't tell spot instances to 'stop'; they will always be
+ # 'terminate'd. For convenience, we'll ignore the latter value.
+ if instance_initiated_shutdown_behavior and instance_initiated_shutdown_behavior != 'terminate':
+ module.fail_json(
+ msg="instance_initiated_shutdown_behavior=stop is not supported for spot instances.")
+
+ if spot_launch_group and isinstance(spot_launch_group, string_types):
+ params['launch_group'] = spot_launch_group
+
+ params.update(dict(
+ count=count_remaining,
+ type=spot_type,
+ ))
+
+ # Set spot ValidUntil
+ # ValidUntil -> (timestamp). The end date of the request, in
+            # UTC format (for example, YYYY-MM-DDTHH:MM:SSZ).
+ utc_valid_until = (
+ datetime.datetime.utcnow()
+ + datetime.timedelta(seconds=spot_wait_timeout))
+ params['valid_until'] = utc_valid_until.strftime('%Y-%m-%dT%H:%M:%S.000Z')
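+            # e.g. spot_wait_timeout=600 for a request submitted at 12:00:00Z
+            # keeps it valid until 12:10:00.000Z, after which an unfulfilled
+            # request expires.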
+
+ res = ec2.request_spot_instances(spot_price, **params)
+
+ # Now we have to do the intermediate waiting
+ if wait:
+ instids = await_spot_requests(module, ec2, res, count)
+ else:
+ instids = []
+ except boto.exception.BotoServerError as e:
+ module.fail_json(msg="Instance creation failed => %s: %s" % (e.error_code, e.error_message))
+
+ # wait here until the instances are up
+ num_running = 0
+ wait_timeout = time.time() + wait_timeout
+ res_list = ()
+ while wait_timeout > time.time() and num_running < len(instids):
+ try:
+ res_list = ec2.get_all_instances(instids)
+ except boto.exception.BotoServerError as e:
+ if e.error_code == 'InvalidInstanceID.NotFound':
+ time.sleep(1)
+ continue
+ else:
+ raise
+
+ num_running = 0
+ for res in res_list:
+ num_running += len([i for i in res.instances if i.state == 'running'])
+ if len(res_list) <= 0:
+ # got a bad response of some sort, possibly due to
+ # stale/cached data. Wait a second and then try again
+ time.sleep(1)
+ continue
+ if wait and num_running < len(instids):
+ time.sleep(5)
+ else:
+ break
+
+ if wait and wait_timeout <= time.time():
+ # waiting took too long
+ module.fail_json(msg="wait for instances running timeout on %s" % time.asctime())
+
+ # We do this after the loop ends so that we end up with one list
+ for res in res_list:
+ running_instances.extend(res.instances)
+
+ # Enabled by default by AWS
+ if source_dest_check is False:
+ for inst in res.instances:
+ inst.modify_attribute('sourceDestCheck', False)
+
+ # Disabled by default by AWS
+ if termination_protection is True:
+ for inst in res.instances:
+ inst.modify_attribute('disableApiTermination', True)
+
+ # Leave this as late as possible to try and avoid InvalidInstanceID.NotFound
+ if instance_tags and instids:
+ try:
+ ec2.create_tags(instids, instance_tags)
+ except boto.exception.EC2ResponseError as e:
+ module.fail_json(msg="Instance tagging failed => %s: %s" % (e.error_code, e.error_message))
+
+ instance_dict_array = []
+ created_instance_ids = []
+ for inst in running_instances:
+ inst.update()
+ d = get_instance_info(inst)
+ created_instance_ids.append(inst.id)
+ instance_dict_array.append(d)
+
+ return (instance_dict_array, created_instance_ids, changed)
+
+
+def terminate_instances(module, ec2, instance_ids):
+ """
+ Terminates a list of instances
+
+ module: Ansible module object
+ ec2: authenticated ec2 connection object
+    instance_ids: a list of instance IDs to terminate, in the form of
+      [ '<inst-id>', ..]
+
+ Returns a dictionary of instance information
+ about the instances terminated.
+
+    If none of the instances were in a running or
+    stopped state, "changed" will be set to False.
+
+ """
+
+ # Whether to wait for termination to complete before returning
+ wait = module.params.get('wait')
+ wait_timeout = int(module.params.get('wait_timeout'))
+
+ changed = False
+ instance_dict_array = []
+
+ if not isinstance(instance_ids, list) or len(instance_ids) < 1:
+ module.fail_json(msg='instance_ids should be a list of instances, aborting')
+
+ terminated_instance_ids = []
+ for res in ec2.get_all_instances(instance_ids):
+ for inst in res.instances:
+ if inst.state == 'running' or inst.state == 'stopped':
+ terminated_instance_ids.append(inst.id)
+ instance_dict_array.append(get_instance_info(inst))
+ try:
+ ec2.terminate_instances([inst.id])
+ except EC2ResponseError as e:
+ module.fail_json(msg='Unable to terminate instance {0}, error: {1}'.format(inst.id, e))
+ changed = True
+
+ # wait here until the instances are 'terminated'
+ if wait:
+ num_terminated = 0
+ wait_timeout = time.time() + wait_timeout
+ while wait_timeout > time.time() and num_terminated < len(terminated_instance_ids):
+ response = ec2.get_all_instances(instance_ids=terminated_instance_ids,
+ filters={'instance-state-name': 'terminated'})
+ try:
+ num_terminated = sum([len(res.instances) for res in response])
+ except Exception as e:
+ # got a bad response of some sort, possibly due to
+ # stale/cached data. Wait a second and then try again
+ time.sleep(1)
+ continue
+
+ if num_terminated < len(terminated_instance_ids):
+ time.sleep(5)
+
+ # waiting took too long
+ if wait_timeout < time.time() and num_terminated < len(terminated_instance_ids):
+ module.fail_json(msg="wait for instance termination timeout on %s" % time.asctime())
+ # Lets get the current state of the instances after terminating - issue600
+ instance_dict_array = []
+ for res in ec2.get_all_instances(instance_ids=terminated_instance_ids, filters={'instance-state-name': 'terminated'}):
+ for inst in res.instances:
+ instance_dict_array.append(get_instance_info(inst))
+
+ return (changed, instance_dict_array, terminated_instance_ids)
+
+
+def startstop_instances(module, ec2, instance_ids, state, instance_tags):
+ """
+ Starts or stops a list of existing instances
+
+ module: Ansible module object
+ ec2: authenticated ec2 connection object
+ instance_ids: The list of instance ids (strings) to start or stop
+ instance_tags: A dict of tag keys and values in the form of
+ {key: value, ... }
+ state: Intended state ("running" or "stopped")
+
+ Returns a tuple of (changed, instance_dict_array, instance_ids).
+
+ If the instance was not able to change state,
+ "changed" will be set to False.
+
+ Note that if instance_ids and instance_tags are both non-empty,
+ this method will process the intersection of the two
+ """
+
+ wait = module.params.get('wait')
+ wait_timeout = int(module.params.get('wait_timeout'))
+ group_id = module.params.get('group_id')
+ group_name = module.params.get('group')
+ changed = False
+ instance_dict_array = []
+
+ if not isinstance(instance_ids, list) or len(instance_ids) < 1:
+ # Fail unless the user defined instance tags
+ if not instance_tags:
+ module.fail_json(msg='instance_ids should be a list of instances, aborting')
+
+ # To make an EC2 tag filter, we need to prepend 'tag:' to each key.
+ # An empty filter does no filtering, so it's safe to pass it to the
+ # get_all_instances method even if the user did not specify instance_tags
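+ # Illustrative: instance_tags={'env': 'prod'} yields filters={'tag:env': 'prod'}.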
+ filters = {}
+ if instance_tags:
+ for key, value in instance_tags.items():
+ filters["tag:" + key] = value
+
+ if module.params.get('id'):
+ filters['client-token'] = module.params['id']
+
+ # Check (and possibly change) instance attributes and instance state
+ existing_instances_array = []
+ for res in ec2.get_all_instances(instance_ids, filters=filters):
+ for inst in res.instances:
+
+ warn_if_public_ip_assignment_changed(module, inst)
+
+ changed = (check_source_dest_attr(module, inst, ec2) or
+ check_termination_protection(module, inst) or changed)
+
+ # Check security groups if we're using ec2-vpc; ec2-classic security groups may not be modified
+ if inst.vpc_id and group_name:
+ grp_details = ec2.get_all_security_groups(filters={'vpc_id': inst.vpc_id})
+ if isinstance(group_name, string_types):
+ group_name = [group_name]
+ unmatched = set(group_name) - set(to_text(grp.name) for grp in grp_details)
+ if unmatched:
+ module.fail_json(msg="The following group names are not valid: %s" % ', '.join(unmatched))
+ group_ids = [to_text(grp.id) for grp in grp_details if to_text(grp.name) in group_name]
+ elif inst.vpc_id and group_id:
+ if isinstance(group_id, string_types):
+ group_id = [group_id]
+ grp_details = ec2.get_all_security_groups(group_ids=group_id)
+ group_ids = [grp_item.id for grp_item in grp_details]
+ if inst.vpc_id and (group_name or group_id):
+ if set(sg.id for sg in inst.groups) != set(group_ids):
+ changed = inst.modify_attribute('groupSet', group_ids)
+
+ # Check instance state
+ if inst.state != state:
+ instance_dict_array.append(get_instance_info(inst))
+ try:
+ if state == 'running':
+ inst.start()
+ else:
+ inst.stop()
+ except EC2ResponseError as e:
+ module.fail_json(msg='Unable to change state for instance {0}, error: {1}'.format(inst.id, e))
+ changed = True
+ existing_instances_array.append(inst.id)
+
+ instance_ids = list(set(existing_instances_array + (instance_ids or [])))
+ # Wait for all the instances to finish starting or stopping
+ wait_timeout = time.time() + wait_timeout
+ while wait and wait_timeout > time.time():
+ instance_dict_array = []
+ matched_instances = []
+ for res in ec2.get_all_instances(instance_ids):
+ for i in res.instances:
+ if i.state == state:
+ instance_dict_array.append(get_instance_info(i))
+ matched_instances.append(i)
+ if len(matched_instances) < len(instance_ids):
+ time.sleep(5)
+ else:
+ break
+
+ if wait and wait_timeout <= time.time():
+ # waiting took too long
+ module.fail_json(msg="wait for instances running timeout on %s" % time.asctime())
+
+ return (changed, instance_dict_array, instance_ids)
+
+
+def restart_instances(module, ec2, instance_ids, state, instance_tags):
+ """
+ Restarts a list of existing instances
+
+ module: Ansible module object
+ ec2: authenticated ec2 connection object
+ instance_ids: The list of instance ids (strings) to restart
+ instance_tags: A dict of tag keys and values in the form of
+ {key: value, ... }
+ state: Intended state ("restarted")
+
+ Returns a tuple of (changed, instance_dict_array, instance_ids).
+
+ If the instance was not able to change state,
+ "changed" will be set to False.
+
+ Wait does not apply here, as this is an OS-level operation.
+
+ Note that if instance_ids and instance_tags are both non-empty,
+ this method will process the intersection of the two.
+ """
+
+ changed = False
+ instance_dict_array = []
+
+ if not isinstance(instance_ids, list) or len(instance_ids) < 1:
+ # Fail unless the user defined instance tags
+ if not instance_tags:
+ module.fail_json(msg='instance_ids should be a list of instances, aborting')
+
+ # To make an EC2 tag filter, we need to prepend 'tag:' to each key.
+ # An empty filter does no filtering, so it's safe to pass it to the
+ # get_all_instances method even if the user did not specify instance_tags
+ filters = {}
+ if instance_tags:
+ for key, value in instance_tags.items():
+ filters["tag:" + key] = value
+ if module.params.get('id'):
+ filters['client-token'] = module.params['id']
+
+ # Check (and possibly change) instance attributes and instance state
+ for res in ec2.get_all_instances(instance_ids, filters=filters):
+ for inst in res.instances:
+
+ warn_if_public_ip_assignment_changed(module, inst)
+
+ changed = (check_source_dest_attr(module, inst, ec2) or
+ check_termination_protection(module, inst) or changed)
+
+ # Check instance state
+ if inst.state != state:
+ instance_dict_array.append(get_instance_info(inst))
+ try:
+ inst.reboot()
+ except EC2ResponseError as e:
+ module.fail_json(msg='Unable to change state for instance {0}, error: {1}'.format(inst.id, e))
+ changed = True
+
+ return (changed, instance_dict_array, instance_ids)
+
+
+def check_termination_protection(module, inst):
+ """
+ Check the instance disableApiTermination attribute.
+
+ module: Ansible module object
+ inst: EC2 instance object
+
+ returns: True if state changed, None otherwise
+ """
+
+ termination_protection = module.params.get('termination_protection')
+
+ if (inst.get_attribute('disableApiTermination')['disableApiTermination'] != termination_protection and termination_protection is not None):
+ inst.modify_attribute('disableApiTermination', termination_protection)
+ return True
+
+
+def check_source_dest_attr(module, inst, ec2):
+ """
+ Check the instance sourceDestCheck attribute.
+
+ module: Ansible module object
+ inst: EC2 instance object
+
+ returns: True if state changed, None otherwise
+ """
+
+ source_dest_check = module.params.get('source_dest_check')
+
+ if source_dest_check is not None:
+ try:
+ if inst.vpc_id is not None and inst.get_attribute('sourceDestCheck')['sourceDestCheck'] != source_dest_check:
+ inst.modify_attribute('sourceDestCheck', source_dest_check)
+ return True
+ except boto.exception.EC2ResponseError as exc:
+ # instances with more than one Elastic Network Interface will
+ # fail, because they have the sourceDestCheck attribute defined
+ # per-interface
+ if exc.code == 'InvalidInstanceID':
+ for interface in inst.interfaces:
+ if interface.source_dest_check != source_dest_check:
+ ec2.modify_network_interface_attribute(interface.id, "sourceDestCheck", source_dest_check)
+ return True
+ else:
+ module.fail_json(msg='Failed to handle source_dest_check state for instance {0}, error: {1}'.format(inst.id, exc),
+ exception=traceback.format_exc())
+
+
+def warn_if_public_ip_assignment_changed(module, instance):
+ # This is a non-modifiable attribute.
+ assign_public_ip = module.params.get('assign_public_ip')
+
+ # Check that public ip assignment is the same and warn if not
+ public_dns_name = getattr(instance, 'public_dns_name', None)
+ if (assign_public_ip or public_dns_name) and (not public_dns_name or assign_public_ip is False):
+ module.warn("Unable to modify public ip assignment to {0} for instance {1}. "
+ "Whether or not to assign a public IP is determined during instance creation.".format(assign_public_ip, instance.id))
+
+
+def main():
+ argument_spec = ec2_argument_spec()
+ argument_spec.update(
+ dict(
+ key_name=dict(aliases=['keypair']),
+ id=dict(),
+ group=dict(type='list', aliases=['groups']),
+ group_id=dict(type='list'),
+ zone=dict(aliases=['aws_zone', 'ec2_zone']),
+ instance_type=dict(aliases=['type']),
+ spot_price=dict(),
+ spot_type=dict(default='one-time', choices=["one-time", "persistent"]),
+ spot_launch_group=dict(),
+ image=dict(),
+ kernel=dict(),
+ count=dict(type='int', default=1),
+ monitoring=dict(type='bool', default=False),
+ ramdisk=dict(),
+ wait=dict(type='bool', default=False),
+ wait_timeout=dict(type='int', default=300),
+ spot_wait_timeout=dict(type='int', default=600),
+ placement_group=dict(),
+ user_data=dict(),
+ instance_tags=dict(type='dict'),
+ vpc_subnet_id=dict(),
+ assign_public_ip=dict(type='bool'),
+ private_ip=dict(),
+ instance_profile_name=dict(),
+ instance_ids=dict(type='list', aliases=['instance_id']),
+ source_dest_check=dict(type='bool', default=None),
+ termination_protection=dict(type='bool', default=None),
+ state=dict(default='present', choices=['present', 'absent', 'running', 'restarted', 'stopped']),
+ instance_initiated_shutdown_behavior=dict(default='stop', choices=['stop', 'terminate']),
+ exact_count=dict(type='int', default=None),
+ count_tag=dict(type='raw'),
+ volumes=dict(type='list'),
+ ebs_optimized=dict(type='bool', default=False),
+ tenancy=dict(default='default', choices=['default', 'dedicated']),
+ network_interfaces=dict(type='list', aliases=['network_interface'])
+ )
+ )
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ mutually_exclusive=[
+ # Can be uncommented when we finish the deprecation cycle.
+ # ['group', 'group_id'],
+ ['exact_count', 'count'],
+ ['exact_count', 'state'],
+ ['exact_count', 'instance_ids'],
+ ['network_interfaces', 'assign_public_ip'],
+ ['network_interfaces', 'group'],
+ ['network_interfaces', 'group_id'],
+ ['network_interfaces', 'private_ip'],
+ ['network_interfaces', 'vpc_subnet_id'],
+ ],
+ )
+
+ if module.params.get('group') and module.params.get('group_id'):
+ module.deprecate(
+ msg='Support for passing both group and group_id has been deprecated. '
+ 'Currently group_id is ignored, in future passing both will result in an error',
+ version='2.14', collection_name='ansible.builtin')
+
+ if not HAS_BOTO:
+ module.fail_json(msg='boto required for this module')
+
+ try:
+ region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module)
+ if module.params.get('region') or not module.params.get('ec2_url'):
+ ec2 = ec2_connect(module)
+ elif module.params.get('ec2_url'):
+ ec2 = connect_ec2_endpoint(ec2_url, **aws_connect_kwargs)
+
+ if 'region' not in aws_connect_kwargs:
+ aws_connect_kwargs['region'] = ec2.region
+
+ vpc = connect_vpc(**aws_connect_kwargs)
+ except boto.exception.NoAuthHandlerFound as e:
+ module.fail_json(msg="Failed to get connection: %s" % e.message, exception=traceback.format_exc())
+
+ tagged_instances = []
+
+ state = module.params['state']
+
+ if state == 'absent':
+ instance_ids = module.params['instance_ids']
+ if not instance_ids:
+ module.fail_json(msg='instance_ids list is required for absent state')
+
+ (changed, instance_dict_array, new_instance_ids) = terminate_instances(module, ec2, instance_ids)
+
+ elif state in ('running', 'stopped'):
+ instance_ids = module.params.get('instance_ids')
+ instance_tags = module.params.get('instance_tags')
+ if not (isinstance(instance_ids, list) or isinstance(instance_tags, dict)):
+ module.fail_json(msg='running/stopped state requires a list of instance ids or a dict of instance tags: %s' % instance_ids)
+
+ (changed, instance_dict_array, new_instance_ids) = startstop_instances(module, ec2, instance_ids, state, instance_tags)
+
+ elif state == 'restarted':
+ instance_ids = module.params.get('instance_ids')
+ instance_tags = module.params.get('instance_tags')
+ if not (isinstance(instance_ids, list) or isinstance(instance_tags, dict)):
+ module.fail_json(msg='restarted state requires a list of instance ids or a dict of instance tags: %s' % instance_ids)
+
+ (changed, instance_dict_array, new_instance_ids) = restart_instances(module, ec2, instance_ids, state, instance_tags)
+
+ elif state == 'present':
+ # Changed is always set to true when provisioning new instances
+ if not module.params.get('image'):
+ module.fail_json(msg='image parameter is required for new instance')
+
+ if module.params.get('exact_count') is None:
+ (instance_dict_array, new_instance_ids, changed) = create_instances(module, ec2, vpc)
+ else:
+ (tagged_instances, instance_dict_array, new_instance_ids, changed) = enforce_count(module, ec2, vpc)
+
+ # Always return instances in the same order
+ if new_instance_ids:
+ new_instance_ids.sort()
+ if instance_dict_array:
+ instance_dict_array.sort(key=lambda x: x['id'])
+ if tagged_instances:
+ tagged_instances.sort(key=lambda x: x['id'])
+
+ module.exit_json(changed=changed, instance_ids=new_instance_ids, instances=instance_dict_array, tagged_instances=tagged_instances)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/ec2_ami_info.py b/test/support/integration/plugins/modules/ec2_ami_info.py
new file mode 100644
index 00000000..53c2374d
--- /dev/null
+++ b/test/support/integration/plugins/modules/ec2_ami_info.py
@@ -0,0 +1,282 @@
+#!/usr/bin/python
+# Copyright: Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: ec2_ami_info
+version_added: '2.5'
+short_description: Gather information about ec2 AMIs
+description:
+ - Gather information about ec2 AMIs
+ - This module was called C(ec2_ami_facts) before Ansible 2.9. The usage did not change.
+author:
+ - Prasad Katti (@prasadkatti)
+requirements: [ boto3 ]
+options:
+ image_ids:
+ description: One or more image IDs.
+ aliases: [image_id]
+ type: list
+ elements: str
+ filters:
+ description:
+ - A dict of filters to apply. Each dict item consists of a filter key and a filter value.
+ - See U(https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeImages.html) for possible filters.
+ - Filter names and values are case sensitive.
+ type: dict
+ owners:
+ description:
+ - Filter the images by the owner. Valid options are an AWS account ID, self,
+ or an AWS owner alias ( amazon | aws-marketplace | microsoft ).
+ aliases: [owner]
+ type: list
+ elements: str
+ executable_users:
+ description:
+ - Filter images by users with explicit launch permissions. Valid options are an AWS account ID, self, or all (public AMIs).
+ aliases: [executable_user]
+ type: list
+ elements: str
+ describe_image_attributes:
+ description:
+ - Describe attributes (like launchPermission) of the images found.
+ default: no
+ type: bool
+
+extends_documentation_fragment:
+ - aws
+ - ec2
+'''
+
+EXAMPLES = '''
+# Note: These examples do not set authentication details, see the AWS Guide for details.
+
+- name: gather information about an AMI using ami-id
+ ec2_ami_info:
+ image_ids: ami-5b488823
+
+- name: gather information about all AMIs with tag key Name and value webapp
+ ec2_ami_info:
+ filters:
+ "tag:Name": webapp
+
+- name: gather information about an AMI with 'AMI Name' equal to foobar
+ ec2_ami_info:
+ filters:
+ name: foobar
+
+- name: gather information about Ubuntu 17.04 AMIs published by Canonical (099720109477)
+ ec2_ami_info:
+ owners: 099720109477
+ filters:
+ name: "ubuntu/images/ubuntu-zesty-17.04-*"
+'''
+
+RETURN = '''
+images:
+ description: A list of images.
+ returned: always
+ type: list
+ elements: dict
+ contains:
+ architecture:
+ description: The architecture of the image.
+ returned: always
+ type: str
+ sample: x86_64
+ block_device_mappings:
+ description: Any block device mapping entries.
+ returned: always
+ type: list
+ elements: dict
+ contains:
+ device_name:
+ description: The device name exposed to the instance.
+ returned: always
+ type: str
+ sample: /dev/sda1
+ ebs:
+ description: EBS volumes
+ returned: always
+ type: complex
+ creation_date:
+ description: The date and time the image was created.
+ returned: always
+ type: str
+ sample: '2017-10-16T19:22:13.000Z'
+ description:
+ description: The description of the AMI.
+ returned: always
+ type: str
+ sample: ''
+ ena_support:
+ description: Whether enhanced networking with ENA is enabled.
+ returned: always
+ type: bool
+ sample: true
+ hypervisor:
+ description: The hypervisor type of the image.
+ returned: always
+ type: str
+ sample: xen
+ image_id:
+ description: The ID of the AMI.
+ returned: always
+ type: str
+ sample: ami-5b466623
+ image_location:
+ description: The location of the AMI.
+ returned: always
+ type: str
+ sample: 408466080000/Webapp
+ image_type:
+ description: The type of image.
+ returned: always
+ type: str
+ sample: machine
+ launch_permissions:
+ description: A list of AWS accounts that may launch the AMI.
+ returned: When the image is owned by the calling account and I(describe_image_attributes) is yes.
+ type: list
+ elements: dict
+ contains:
+ group:
+ description: A value of 'all' means the AMI is public.
+ type: str
+ user_id:
+ description: An AWS account ID with permissions to launch the AMI.
+ type: str
+ sample: [{"group": "all"}, {"user_id": "408466080000"}]
+ name:
+ description: The name of the AMI that was provided during image creation.
+ returned: always
+ type: str
+ sample: Webapp
+ owner_id:
+ description: The AWS account ID of the image owner.
+ returned: always
+ type: str
+ sample: '408466080000'
+ public:
+ description: Whether the image has public launch permissions.
+ returned: always
+ type: bool
+ sample: true
+ root_device_name:
+ description: The device name of the root device.
+ returned: always
+ type: str
+ sample: /dev/sda1
+ root_device_type:
+ description: The type of root device used by the AMI.
+ returned: always
+ type: str
+ sample: ebs
+ sriov_net_support:
+ description: Whether enhanced networking is enabled.
+ returned: always
+ type: str
+ sample: simple
+ state:
+ description: The current state of the AMI.
+ returned: always
+ type: str
+ sample: available
+ tags:
+ description: Any tags assigned to the image.
+ returned: always
+ type: dict
+ virtualization_type:
+ description: The type of virtualization of the AMI.
+ returned: always
+ type: str
+ sample: hvm
+'''
+
+try:
+ from botocore.exceptions import ClientError, BotoCoreError
+except ImportError:
+ pass # caught by AnsibleAWSModule
+
+from ansible.module_utils.aws.core import AnsibleAWSModule
+from ansible.module_utils.ec2 import ansible_dict_to_boto3_filter_list, camel_dict_to_snake_dict, boto3_tag_list_to_ansible_dict
+
+
+def list_ec2_images(ec2_client, module):
+
+ image_ids = module.params.get("image_ids")
+ owners = module.params.get("owners")
+ executable_users = module.params.get("executable_users")
+ filters = module.params.get("filters")
+ owner_param = []
+
+ # describe_images is *very* slow if you pass the `Owners`
+ # param (unless it's self), for some reason.
+ # Converting the owners to filters and removing from the
+ # owners param greatly speeds things up.
+ # Implementation based on aioue's suggestion in #24886
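+ # Illustrative: owners=['123456789012', 'amazon', 'self'] becomes
+ # filters={'owner-id': ['123456789012'], 'owner-alias': ['amazon']}
+ # with owner_param=['self'].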
+ for owner in owners:
+ if owner.isdigit():
+ if 'owner-id' not in filters:
+ filters['owner-id'] = list()
+ filters['owner-id'].append(owner)
+ elif owner == 'self':
+ # self not a valid owner-alias filter (https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeImages.html)
+ owner_param.append(owner)
+ else:
+ if 'owner-alias' not in filters:
+ filters['owner-alias'] = list()
+ filters['owner-alias'].append(owner)
+
+ filters = ansible_dict_to_boto3_filter_list(filters)
+
+ try:
+ images = ec2_client.describe_images(ImageIds=image_ids, Filters=filters, Owners=owner_param, ExecutableUsers=executable_users)
+ images = [camel_dict_to_snake_dict(image) for image in images["Images"]]
+ except (ClientError, BotoCoreError) as err:
+ module.fail_json_aws(err, msg="error describing images")
+ for image in images:
+ try:
+ image['tags'] = boto3_tag_list_to_ansible_dict(image.get('tags', []))
+ if module.params.get("describe_image_attributes"):
+ launch_permissions = ec2_client.describe_image_attribute(Attribute='launchPermission', ImageId=image['image_id'])['LaunchPermissions']
+ image['launch_permissions'] = [camel_dict_to_snake_dict(perm) for perm in launch_permissions]
+ except (ClientError, BotoCoreError) as err:
+ # describing launch permissions of images owned by others is not permitted, but shouldn't cause failures
+ pass
+
+ images.sort(key=lambda e: e.get('creation_date', '')) # sort by creation_date, which may be missing on some images
+ module.exit_json(images=images)
+
+
+def main():
+
+ argument_spec = dict(
+ image_ids=dict(default=[], type='list', aliases=['image_id']),
+ filters=dict(default={}, type='dict'),
+ owners=dict(default=[], type='list', aliases=['owner']),
+ executable_users=dict(default=[], type='list', aliases=['executable_user']),
+ describe_image_attributes=dict(default=False, type='bool')
+ )
+
+ module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True)
+ if module._module._name == 'ec2_ami_facts':
+ module._module.deprecate("The 'ec2_ami_facts' module has been renamed to 'ec2_ami_info'",
+ version='2.13', collection_name='ansible.builtin')
+
+ ec2_client = module.client('ec2')
+
+ list_ec2_images(ec2_client, module)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/ec2_group.py b/test/support/integration/plugins/modules/ec2_group.py
new file mode 100644
index 00000000..bc416f66
--- /dev/null
+++ b/test/support/integration/plugins/modules/ec2_group.py
@@ -0,0 +1,1345 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+# This file is part of Ansible
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = '''
+---
+module: ec2_group
+author: "Andrew de Quincey (@adq)"
+version_added: "1.3"
+requirements: [ boto3 ]
+short_description: maintain an ec2 VPC security group.
+description:
+ - Maintains ec2 security groups. This module has a dependency on boto3.
+options:
+ name:
+ description:
+ - Name of the security group.
+ - One of and only one of I(name) or I(group_id) is required.
+ - Required if I(state=present).
+ required: false
+ type: str
+ group_id:
+ description:
+ - Id of group to delete (works only with absent).
+ - One of and only one of I(name) or I(group_id) is required.
+ required: false
+ version_added: "2.4"
+ type: str
+ description:
+ description:
+ - Description of the security group. Required when C(state) is C(present).
+ required: false
+ type: str
+ vpc_id:
+ description:
+ - ID of the VPC to create the group in.
+ required: false
+ type: str
+ rules:
+ description:
+ - List of firewall inbound rules to enforce in this group (see example). If none are supplied,
+ no inbound rules will be enabled. Rules list may include its own name in `group_name`.
+ This allows idempotent loopback additions (e.g. allow group to access itself).
+ Rule sources list support was added in version 2.4. This allows defining multiple sources per
+ source type, as well as multiple source types per rule. Prior to 2.4, only an individual source was allowed.
+ In version 2.5 support for rule descriptions was added.
+ required: false
+ type: list
+ elements: dict
+ suboptions:
+ cidr_ip:
+ type: str
+ description:
+ - The IPv4 CIDR range traffic is coming from.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ cidr_ipv6:
+ type: str
+ description:
+ - The IPv6 CIDR range traffic is coming from.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ ip_prefix:
+ type: str
+ description:
+ - The IP Prefix U(https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-prefix-lists.html)
+ that traffic is coming from.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ group_id:
+ type: str
+ description:
+ - The ID of the Security Group that traffic is coming from.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ group_name:
+ type: str
+ description:
+ - Name of the Security Group that traffic is coming from.
+ - If the Security Group doesn't exist a new Security Group will be
+ created with I(group_desc) as the description.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ group_desc:
+ type: str
+ description:
+ - If the I(group_name) is set and the Security Group doesn't exist a new Security Group will be
+ created with I(group_desc) as the description.
+ proto:
+ type: str
+ description:
+ - The IP protocol name (C(tcp), C(udp), C(icmp), C(icmpv6)) or number (U(https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers))
+ from_port:
+ type: int
+ description: The start of the range of ports that traffic is coming from. A value of C(-1) indicates all ports.
+ to_port:
+ type: int
+ description: The end of the range of ports that traffic is coming from. A value of C(-1) indicates all ports.
+ rule_desc:
+ type: str
+ description: A description for the rule.
+ rules_egress:
+ description:
+ - List of firewall outbound rules to enforce in this group (see example). If none are supplied,
+ a default all-out rule is assumed. If an empty list is supplied, no outbound rules will be enabled.
+ Rule Egress sources list support was added in version 2.4. In version 2.5 support for rule descriptions
+ was added.
+ required: false
+ version_added: "1.6"
+ type: list
+ elements: dict
+ suboptions:
+ cidr_ip:
+ type: str
+ description:
+ - The IPv4 CIDR range traffic is going to.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ cidr_ipv6:
+ type: str
+ description:
+ - The IPv6 CIDR range traffic is going to.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ ip_prefix:
+ type: str
+ description:
+ - The IP Prefix U(https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-prefix-lists.html)
+ that traffic is going to.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ group_id:
+ type: str
+ description:
+ - The ID of the Security Group that traffic is going to.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ group_name:
+ type: str
+ description:
+ - Name of the Security Group that traffic is going to.
+ - If the Security Group doesn't exist a new Security Group will be
+ created with I(group_desc) as the description.
+ - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id)
+ and I(group_name).
+ group_desc:
+ type: str
+ description:
+ - If the I(group_name) is set and the Security Group doesn't exist a new Security Group will be
+ created with I(group_desc) as the description.
+ proto:
+ type: str
+ description:
+ - The IP protocol name (C(tcp), C(udp), C(icmp), C(icmpv6)) or number (U(https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers))
+ from_port:
+ type: int
+ description: The start of the range of ports that traffic is going to. A value of C(-1) indicates all ports.
+ to_port:
+ type: int
+ description: The end of the range of ports that traffic is going to. A value of C(-1) indicates all ports.
+ rule_desc:
+ type: str
+ description: A description for the rule.
+ state:
+ version_added: "1.4"
+ description:
+ - Create or delete a security group.
+ required: false
+ default: 'present'
+ choices: [ "present", "absent" ]
+ aliases: []
+ type: str
+ purge_rules:
+ version_added: "1.8"
+ description:
+ - Purge existing rules on security group that are not found in rules.
+ required: false
+ default: 'true'
+ aliases: []
+ type: bool
+ purge_rules_egress:
+ version_added: "1.8"
+ description:
+ - Purge existing rules_egress on security group that are not found in rules_egress.
+ required: false
+ default: 'true'
+ aliases: []
+ type: bool
+ tags:
+ version_added: "2.4"
+ description:
+ - A dictionary of one or more tags to assign to the security group.
+ required: false
+ type: dict
+ aliases: ['resource_tags']
+ purge_tags:
+ version_added: "2.4"
+ description:
+ - If yes, existing tags will be purged from the resource to match exactly what is defined by I(tags) parameter. If the I(tags) parameter is not set then
+ tags will not be modified.
+ required: false
+ default: yes
+ type: bool
+
+extends_documentation_fragment:
+ - aws
+ - ec2
+
+notes:
+ - If a rule declares a group_name and that group doesn't exist, it will be
+ automatically created. In that case, group_desc should be provided as well.
+ The module will refuse to create a depended-on group without a description.
+ - Preview diff mode support is added in version 2.7.
+'''
+
+EXAMPLES = '''
+- name: example using security group rule descriptions
+ ec2_group:
+ name: "{{ name }}"
+ description: sg with rule descriptions
+ vpc_id: vpc-xxxxxxxx
+ profile: "{{ aws_profile }}"
+ region: us-east-1
+ rules:
+ - proto: tcp
+ ports:
+ - 80
+ cidr_ip: 0.0.0.0/0
+ rule_desc: allow all on port 80
+
+- name: example ec2 group
+ ec2_group:
+ name: example
+ description: an example EC2 group
+ vpc_id: 12345
+ region: eu-west-1
+ aws_secret_key: SECRET
+ aws_access_key: ACCESS
+ rules:
+ - proto: tcp
+ from_port: 80
+ to_port: 80
+ cidr_ip: 0.0.0.0/0
+ - proto: tcp
+ from_port: 22
+ to_port: 22
+ cidr_ip: 10.0.0.0/8
+ - proto: tcp
+ from_port: 443
+ to_port: 443
+ # this should only be needed for EC2 Classic security group rules
+ # because in a VPC an ELB will use a user-account security group
+ group_id: amazon-elb/sg-87654321/amazon-elb-sg
+ - proto: tcp
+ from_port: 3306
+ to_port: 3306
+ group_id: 123412341234/sg-87654321/exact-name-of-sg
+ - proto: udp
+ from_port: 10050
+ to_port: 10050
+ cidr_ip: 10.0.0.0/8
+ - proto: udp
+ from_port: 10051
+ to_port: 10051
+ group_id: sg-12345678
+ - proto: icmp
+ from_port: 8 # icmp type, -1 = any type
+ to_port: -1 # icmp subtype, -1 = any subtype
+ cidr_ip: 10.0.0.0/8
+ - proto: all
+ # the containing group name may be specified here
+ group_name: example
+ - proto: all
+ # in the 'proto' attribute, if you specify -1, all, or a protocol number other than tcp, udp, icmp, or 58 (ICMPv6),
+ # traffic on all ports is allowed, regardless of any ports you specify
+ from_port: 10050 # this value is ignored
+ to_port: 10050 # this value is ignored
+ cidr_ip: 10.0.0.0/8
+
+ rules_egress:
+ - proto: tcp
+ from_port: 80
+ to_port: 80
+ cidr_ip: 0.0.0.0/0
+ cidr_ipv6: 64:ff9b::/96
+ group_name: example-other
+ # description to use if example-other needs to be created
+ group_desc: other example EC2 group
+
+- name: example2 ec2 group
+ ec2_group:
+ name: example2
+ description: an example2 EC2 group
+ vpc_id: 12345
+ region: eu-west-1
+ rules:
+ # 'ports' rule keyword was introduced in version 2.4. It accepts a single port value or a list of values including ranges (from_port-to_port).
+ - proto: tcp
+ ports: 22
+ group_name: example-vpn
+ - proto: tcp
+ ports:
+ - 80
+ - 443
+ - 8080-8099
+ cidr_ip: 0.0.0.0/0
+ # Rule sources list support was added in version 2.4. This allows defining multiple sources per source type, as well as multiple source types per rule.
+ - proto: tcp
+ ports:
+ - 6379
+ - 26379
+ group_name:
+ - example-vpn
+ - example-redis
+ - proto: tcp
+ ports: 5665
+ group_name: example-vpn
+ cidr_ip:
+ - 172.16.1.0/24
+ - 172.16.17.0/24
+ cidr_ipv6:
+ - 2607:F8B0::/32
+ - 64:ff9b::/96
+ group_id:
+ - sg-edcd9784
+ diff: True
+
+- name: "Delete group by its id"
+ ec2_group:
+ region: eu-west-1
+ group_id: sg-33b4ee5b
+ state: absent
+'''
+
+RETURN = '''
+group_name:
+ description: Security group name
+ sample: My Security Group
+ type: str
+ returned: on create/update
+group_id:
+ description: Security group id
+ sample: sg-abcd1234
+ type: str
+ returned: on create/update
+description:
+ description: Description of security group
+ sample: My Security Group
+ type: str
+ returned: on create/update
+tags:
+ description: Tags associated with the security group
+ sample:
+ Name: My Security Group
+ Purpose: protecting stuff
+ type: dict
+ returned: on create/update
+vpc_id:
+ description: ID of VPC to which the security group belongs
+ sample: vpc-abcd1234
+ type: str
+ returned: on create/update
+ip_permissions:
+ description: Inbound rules associated with the security group.
+ sample:
+ - from_port: 8182
+ ip_protocol: tcp
+ ip_ranges:
+ - cidr_ip: "1.1.1.1/32"
+ ipv6_ranges: []
+ prefix_list_ids: []
+ to_port: 8182
+ user_id_group_pairs: []
+ type: list
+ returned: on create/update
+ip_permissions_egress:
+ description: Outbound rules associated with the security group.
+ sample:
+ - ip_protocol: -1
+ ip_ranges:
+ - cidr_ip: "0.0.0.0/0"
+ ipv6_ranges: []
+ prefix_list_ids: []
+ user_id_group_pairs: []
+ type: list
+ returned: on create/update
+owner_id:
+ description: AWS Account ID of the security group
+ sample: 123456789012
+ type: int
+ returned: on create/update
+'''
+
+import json
+import re
+import itertools
+from copy import deepcopy
+from time import sleep
+from collections import namedtuple
+from ansible.module_utils.aws.core import AnsibleAWSModule, is_boto3_error_code
+from ansible.module_utils.aws.iam import get_aws_account_id
+from ansible.module_utils.aws.waiters import get_waiter
+from ansible.module_utils.ec2 import AWSRetry, camel_dict_to_snake_dict, compare_aws_tags
+from ansible.module_utils.ec2 import ansible_dict_to_boto3_filter_list, boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list
+from ansible.module_utils.common.network import to_ipv6_subnet, to_subnet
+from ansible.module_utils.compat.ipaddress import ip_network, IPv6Network
+from ansible.module_utils._text import to_text
+from ansible.module_utils.six import string_types
+
+try:
+ from botocore.exceptions import BotoCoreError, ClientError
+except ImportError:
+ pass # caught by AnsibleAWSModule
+
+
+Rule = namedtuple('Rule', ['port_range', 'protocol', 'target', 'target_type', 'description'])
+valid_targets = set(['ipv4', 'ipv6', 'group', 'ip_prefix'])
+current_account_id = None
+
+
+def rule_cmp(a, b):
+ """Compare rules without descriptions"""
+ for prop in ['port_range', 'protocol', 'target', 'target_type']:
+ if prop == 'port_range' and to_text(a.protocol) == to_text(b.protocol):
+ # equal protocols can interchange `(-1, -1)` and `(None, None)`
+ if a.port_range in ((None, None), (-1, -1)) and b.port_range in ((None, None), (-1, -1)):
+ continue
+ elif getattr(a, prop) != getattr(b, prop):
+ return False
+ elif getattr(a, prop) != getattr(b, prop):
+ return False
+ return True
+
+
+def rules_to_permissions(rules):
+ return [to_permission(rule) for rule in rules]
+
+
+def to_permission(rule):
+ # take a Rule, output the serialized grant
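+ # Illustrative: Rule((80, 80), 'tcp', '1.2.3.4/32', 'ipv4', None) becomes
+ # {'IpProtocol': 'tcp', 'FromPort': 80, 'ToPort': 80,
+ # 'IpRanges': [{'CidrIp': '1.2.3.4/32'}]}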
+ perm = {
+ 'IpProtocol': rule.protocol,
+ }
+ perm['FromPort'], perm['ToPort'] = rule.port_range
+ if rule.target_type == 'ipv4':
+ perm['IpRanges'] = [{
+ 'CidrIp': rule.target,
+ }]
+ if rule.description:
+ perm['IpRanges'][0]['Description'] = rule.description
+ elif rule.target_type == 'ipv6':
+ perm['Ipv6Ranges'] = [{
+ 'CidrIpv6': rule.target,
+ }]
+ if rule.description:
+ perm['Ipv6Ranges'][0]['Description'] = rule.description
+ elif rule.target_type == 'group':
+ if isinstance(rule.target, tuple):
+ pair = {}
+ if rule.target[0]:
+ pair['UserId'] = rule.target[0]
+ # group_id/group_name are mutually exclusive - give group_id more precedence as it is more specific
+ if rule.target[1]:
+ pair['GroupId'] = rule.target[1]
+ elif rule.target[2]:
+ pair['GroupName'] = rule.target[2]
+ perm['UserIdGroupPairs'] = [pair]
+ else:
+ perm['UserIdGroupPairs'] = [{
+ 'GroupId': rule.target
+ }]
+ if rule.description:
+ perm['UserIdGroupPairs'][0]['Description'] = rule.description
+ elif rule.target_type == 'ip_prefix':
+ perm['PrefixListIds'] = [{
+ 'PrefixListId': rule.target,
+ }]
+ if rule.description:
+ perm['PrefixListIds'][0]['Description'] = rule.description
+ elif rule.target_type not in valid_targets:
+ raise ValueError('Invalid target type for rule {0}'.format(rule))
+ return fix_port_and_protocol(perm)
+
+
+def rule_from_group_permission(perm):
+ def ports_from_permission(p):
+ if 'FromPort' not in p and 'ToPort' not in p:
+ return (None, None)
+ return (int(p['FromPort']), int(p['ToPort']))
+
+ # outputs a rule tuple
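+ # Illustrative: one permission whose IpRanges holds two CIDR entries
+ # yields two Rule tuples, one per entry.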
+ for target_key, target_subkey, target_type in [
+ ('IpRanges', 'CidrIp', 'ipv4'),
+ ('Ipv6Ranges', 'CidrIpv6', 'ipv6'),
+ ('PrefixListIds', 'PrefixListId', 'ip_prefix'),
+ ]:
+ if target_key not in perm:
+ continue
+ for r in perm[target_key]:
+ # there may be several IP ranges here, which is ok
+ yield Rule(
+ ports_from_permission(perm),
+ to_text(perm['IpProtocol']),
+ r[target_subkey],
+ target_type,
+ r.get('Description')
+ )
+ if 'UserIdGroupPairs' in perm and perm['UserIdGroupPairs']:
+ for pair in perm['UserIdGroupPairs']:
+ target = (
+ pair.get('UserId', None),
+ pair.get('GroupId', None),
+ pair.get('GroupName', None),
+ )
+ if pair.get('UserId', '').startswith('amazon-'):
+ # amazon-elb and amazon-prefix rules don't need
+ # group-id specified, so remove it when querying
+ # from permission
+ target = (
+ target[0],
+ None,
+ target[2],
+ )
+ elif 'VpcPeeringConnectionId' in pair or pair['UserId'] != current_account_id:
+ target = (
+ pair.get('UserId', None),
+ pair.get('GroupId', None),
+ pair.get('GroupName', None),
+ )
+
+ yield Rule(
+ ports_from_permission(perm),
+ to_text(perm['IpProtocol']),
+ target,
+ 'group',
+ pair.get('Description')
+ )
+
+
+@AWSRetry.backoff(tries=5, delay=5, backoff=2.0, catch_extra_error_codes=['InvalidGroup.NotFound'])
+def get_security_groups_with_backoff(connection, **kwargs):
+ return connection.describe_security_groups(**kwargs)
+
+
+@AWSRetry.backoff(tries=5, delay=5, backoff=2.0)
+def sg_exists_with_backoff(connection, **kwargs):
+ try:
+ return connection.describe_security_groups(**kwargs)
+ except is_boto3_error_code('InvalidGroup.NotFound'):
+ return {'SecurityGroups': []}
+
+
+def deduplicate_rules_args(rules):
+ """Returns unique rules"""
+ if rules is None:
+ return None
+ return list(dict(zip((json.dumps(r, sort_keys=True) for r in rules), rules)).values())
+
+
+def validate_rule(module, rule):
+ VALID_PARAMS = ('cidr_ip', 'cidr_ipv6', 'ip_prefix',
+ 'group_id', 'group_name', 'group_desc',
+ 'proto', 'from_port', 'to_port', 'rule_desc')
+ if not isinstance(rule, dict):
+ module.fail_json(msg='Invalid rule parameter type [%s].' % type(rule))
+ for k in rule:
+ if k not in VALID_PARAMS:
+ module.fail_json(msg='Invalid rule parameter \'{0}\' for rule: {1}'.format(k, rule))
+
+ if 'group_id' in rule and 'cidr_ip' in rule:
+ module.fail_json(msg='Specify group_id OR cidr_ip, not both')
+ elif 'group_name' in rule and 'cidr_ip' in rule:
+ module.fail_json(msg='Specify group_name OR cidr_ip, not both')
+ elif 'group_id' in rule and 'cidr_ipv6' in rule:
+ module.fail_json(msg="Specify group_id OR cidr_ipv6, not both")
+ elif 'group_name' in rule and 'cidr_ipv6' in rule:
+ module.fail_json(msg="Specify group_name OR cidr_ipv6, not both")
+ elif 'cidr_ip' in rule and 'cidr_ipv6' in rule:
+ module.fail_json(msg="Specify cidr_ip OR cidr_ipv6, not both")
+ elif 'group_id' in rule and 'group_name' in rule:
+ module.fail_json(msg='Specify group_id OR group_name, not both')
+
+
+def get_target_from_rule(module, client, rule, name, group, groups, vpc_id):
+ """
+ Returns tuple of (target_type, target, group_created) after validating rule params.
+
+ rule: Dict describing a rule.
+ name: Name of the security group being managed.
+ groups: Dict of all available security groups.
+
+ AWS accepts an ip range or a security group as target of a rule. This
+ function validates the rule specification and returns either a non-None
+ group_id or a non-None ip range.
+ """
+ FOREIGN_SECURITY_GROUP_REGEX = r'^([^/]+)/?(sg-\S+)?/(\S+)'
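+ # Illustrative: 'amazon-elb/sg-87654321/amazon-elb-sg' parses to
+ # ('amazon-elb', 'sg-87654321', 'amazon-elb-sg'); the sg-id part is optional.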
+ group_id = None
+ group_name = None
+ target_group_created = False
+
+ validate_rule(module, rule)
+ if rule.get('group_id') and re.match(FOREIGN_SECURITY_GROUP_REGEX, rule['group_id']):
+ # this is a foreign Security Group. Since you can't fetch it you must create an instance of it
+ owner_id, group_id, group_name = re.match(FOREIGN_SECURITY_GROUP_REGEX, rule['group_id']).groups()
+ group_instance = dict(UserId=owner_id, GroupId=group_id, GroupName=group_name)
+ groups[group_id] = group_instance
+ groups[group_name] = group_instance
+ # group_id/group_name are mutually exclusive - give group_id more precedence as it is more specific
+ if group_id and group_name:
+ group_name = None
+ return 'group', (owner_id, group_id, group_name), False
+ elif 'group_id' in rule:
+ return 'group', rule['group_id'], False
+ elif 'group_name' in rule:
+ group_name = rule['group_name']
+ if group_name == name:
+ group_id = group['GroupId']
+ groups[group_id] = group
+ groups[group_name] = group
+ elif group_name in groups and group.get('VpcId') and groups[group_name].get('VpcId'):
+ # both are VPC groups, this is ok
+ group_id = groups[group_name]['GroupId']
+ elif group_name in groups and not (group.get('VpcId') or groups[group_name].get('VpcId')):
+ # both are EC2 classic, this is ok
+ group_id = groups[group_name]['GroupId']
+ else:
+ auto_group = None
+ filters = {'group-name': group_name}
+ if vpc_id:
+ filters['vpc-id'] = vpc_id
+ # if we got here, either the target group does not exist, or there
+ # is a mix of EC2 classic + VPC groups. Mixing of EC2 classic + VPC
+ # is bad, so we have to create a new SG because no compatible group
+ # exists
+ if not rule.get('group_desc', '').strip():
+ # retry describing the group once
+ try:
+ auto_group = get_security_groups_with_backoff(client, Filters=ansible_dict_to_boto3_filter_list(filters)).get('SecurityGroups', [])[0]
+ except (is_boto3_error_code('InvalidGroup.NotFound'), IndexError):
+ module.fail_json(msg="group %s will be automatically created by rule %s but "
+ "no description was provided" % (group_name, rule))
+ except ClientError as e: # pylint: disable=duplicate-except
+ module.fail_json_aws(e)
+ elif not module.check_mode:
+ params = dict(GroupName=group_name, Description=rule['group_desc'])
+ if vpc_id:
+ params['VpcId'] = vpc_id
+ try:
+ auto_group = client.create_security_group(**params)
+ get_waiter(
+ client, 'security_group_exists',
+ ).wait(
+ GroupIds=[auto_group['GroupId']],
+ )
+ except is_boto3_error_code('InvalidGroup.Duplicate'):
+ # The group exists, but didn't show up in any of our describe-security-groups calls
+ # Try searching on a filter for the name, and allow a retry window for AWS to update
+ # the model on their end.
+ try:
+ auto_group = get_security_groups_with_backoff(client, Filters=ansible_dict_to_boto3_filter_list(filters)).get('SecurityGroups', [])[0]
+ except IndexError as e:
+ module.fail_json(msg="Could not create or use existing group '{0}' in rule. Make sure the group exists".format(group_name))
+ except ClientError as e:
+ module.fail_json_aws(
+ e,
+ msg="Could not create or use existing group '{0}' in rule. Make sure the group exists".format(group_name))
+ if auto_group is not None:
+ group_id = auto_group['GroupId']
+ groups[group_id] = auto_group
+ groups[group_name] = auto_group
+ target_group_created = True
+ return 'group', group_id, target_group_created
+ elif 'cidr_ip' in rule:
+ return 'ipv4', validate_ip(module, rule['cidr_ip']), False
+ elif 'cidr_ipv6' in rule:
+ return 'ipv6', validate_ip(module, rule['cidr_ipv6']), False
+ elif 'ip_prefix' in rule:
+ return 'ip_prefix', rule['ip_prefix'], False
+
+ module.fail_json(msg="Could not match target for rule {0}".format(rule), failed_rule=rule)
+
+
+def ports_expand(ports):
+ # takes a list of ports and returns a list of (port_from, port_to)
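+ # Illustrative: [22, '80-82'] expands to [(22, 22), (80, 82)].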
+ ports_expanded = []
+ for port in ports:
+ if not isinstance(port, string_types):
+ ports_expanded.append((port,) * 2)
+ elif '-' in port:
+ ports_expanded.append(tuple(int(p.strip()) for p in port.split('-', 1)))
+ else:
+ ports_expanded.append((int(port.strip()),) * 2)
+
+ return ports_expanded
+
+
+def rule_expand_ports(rule):
+ # takes a rule dict and returns a list of expanded rule dicts
+ if 'ports' not in rule:
+ if isinstance(rule.get('from_port'), string_types):
+ rule['from_port'] = int(rule.get('from_port'))
+ if isinstance(rule.get('to_port'), string_types):
+ rule['to_port'] = int(rule.get('to_port'))
+ return [rule]
+
+ ports = rule['ports'] if isinstance(rule['ports'], list) else [rule['ports']]
+
+ rule_expanded = []
+ for from_to in ports_expand(ports):
+ temp_rule = rule.copy()
+ del temp_rule['ports']
+ temp_rule['from_port'], temp_rule['to_port'] = sorted(from_to)
+ rule_expanded.append(temp_rule)
+
+ return rule_expanded
+
+
+def rules_expand_ports(rules):
+ # takes a list of rules and expands it based on 'ports'
+ if not rules:
+ return rules
+
+ return [rule for rule_complex in rules
+ for rule in rule_expand_ports(rule_complex)]
+
+
+def rule_expand_source(rule, source_type):
+ # takes a rule dict and returns a list of expanded rule dicts for specified source_type
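+ # Illustrative: a rule with cidr_ip ['10.0.0.0/8', '172.16.0.0/12'] expands
+ # to two rules, each carrying a single cidr_ip.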
+ sources = rule[source_type] if isinstance(rule[source_type], list) else [rule[source_type]]
+ source_types_all = ('cidr_ip', 'cidr_ipv6', 'group_id', 'group_name', 'ip_prefix')
+
+ rule_expanded = []
+ for source in sources:
+ temp_rule = rule.copy()
+ for s in source_types_all:
+ temp_rule.pop(s, None)
+ temp_rule[source_type] = source
+ rule_expanded.append(temp_rule)
+
+ return rule_expanded
+
+
+def rule_expand_sources(rule):
+ # takes a rule dict and returns a list of expanded rule dicts
+ source_types = (stype for stype in ('cidr_ip', 'cidr_ipv6', 'group_id', 'group_name', 'ip_prefix') if stype in rule)
+
+ return [r for stype in source_types
+ for r in rule_expand_source(rule, stype)]
+
+
+def rules_expand_sources(rules):
+ # takes a list of rules and expands it based on 'cidr_ip', 'group_id', 'group_name'
+ if not rules:
+ return rules
+
+ return [rule for rule_complex in rules
+ for rule in rule_expand_sources(rule_complex)]
+
+
+def update_rules_description(module, client, rule_type, group_id, ip_permissions):
+ if module.check_mode:
+ return
+ try:
+ if rule_type == "in":
+ client.update_security_group_rule_descriptions_ingress(GroupId=group_id, IpPermissions=ip_permissions)
+ if rule_type == "out":
+ client.update_security_group_rule_descriptions_egress(GroupId=group_id, IpPermissions=ip_permissions)
+ except (ClientError, BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Unable to update rule description for group %s" % group_id)
+
+
+def fix_port_and_protocol(permission):
+ for key in ('FromPort', 'ToPort'):
+ if key in permission:
+ if permission[key] is None:
+ del permission[key]
+ else:
+ permission[key] = int(permission[key])
+
+ permission['IpProtocol'] = to_text(permission['IpProtocol'])
+
+ return permission
+
+
+def remove_old_permissions(client, module, revoke_ingress, revoke_egress, group_id):
+ if revoke_ingress:
+ revoke(client, module, revoke_ingress, group_id, 'in')
+ if revoke_egress:
+ revoke(client, module, revoke_egress, group_id, 'out')
+ return bool(revoke_ingress or revoke_egress)
+
+
+def revoke(client, module, ip_permissions, group_id, rule_type):
+ if not module.check_mode:
+ try:
+ if rule_type == 'in':
+ client.revoke_security_group_ingress(GroupId=group_id, IpPermissions=ip_permissions)
+ elif rule_type == 'out':
+ client.revoke_security_group_egress(GroupId=group_id, IpPermissions=ip_permissions)
+ except (BotoCoreError, ClientError) as e:
+ rules = 'ingress rules' if rule_type == 'in' else 'egress rules'
+ module.fail_json_aws(e, "Unable to revoke {0}: {1}".format(rules, ip_permissions))
+
+
+def add_new_permissions(client, module, new_ingress, new_egress, group_id):
+ if new_ingress:
+ authorize(client, module, new_ingress, group_id, 'in')
+ if new_egress:
+ authorize(client, module, new_egress, group_id, 'out')
+ return bool(new_ingress or new_egress)
+
+
+def authorize(client, module, ip_permissions, group_id, rule_type):
+ if not module.check_mode:
+ try:
+ if rule_type == 'in':
+ client.authorize_security_group_ingress(GroupId=group_id, IpPermissions=ip_permissions)
+ elif rule_type == 'out':
+ client.authorize_security_group_egress(GroupId=group_id, IpPermissions=ip_permissions)
+ except (BotoCoreError, ClientError) as e:
+ rules = 'ingress rules' if rule_type == 'in' else 'egress rules'
+ module.fail_json_aws(e, "Unable to authorize {0}: {1}".format(rules, ip_permissions))
+
+
+def validate_ip(module, cidr_ip):
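+ # Illustrative: '10.0.0.5/24' has host bits set; to_subnet() yields
+ # '10.0.0.0/24', so we warn and return the masked network instead.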
+ split_addr = cidr_ip.split('/')
+ if len(split_addr) == 2:
+ # this_ip is a IPv4 or IPv6 CIDR that may or may not have host bits set
+ # Get the network bits if IPv4, and validate if IPv6.
+ try:
+ ip = to_subnet(split_addr[0], split_addr[1])
+ if ip != cidr_ip:
+ module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, "
+ "check the network mask and make sure that only network bits are set: {1}.".format(
+ cidr_ip, ip))
+ except ValueError:
+ # to_subnet throws a ValueError on IPv6 networks, so we should be working with v6 if we get here
+ try:
+ isinstance(ip_network(to_text(cidr_ip)), IPv6Network)
+ ip = cidr_ip
+ except ValueError:
+ # If a host bit is set on something other than a /128, IPv6Network will throw a ValueError
+ # The ipv6_cidr in this case probably looks like "2001:DB8:A0B:12F0::1/64" and we just want the network bits
+ ip6 = to_ipv6_subnet(split_addr[0]) + "/" + split_addr[1]
+ if ip6 != cidr_ip:
+ module.warn("One of your IPv6 CIDR addresses ({0}) has host bits set. To get rid of this warning, "
+ "check the network mask and make sure that only network bits are set: {1}.".format(cidr_ip, ip6))
+ return ip6
+ return ip
+ return cidr_ip
+
+
+def update_tags(client, module, group_id, current_tags, tags, purge_tags):
+ tags_need_modify, tags_to_delete = compare_aws_tags(current_tags, tags, purge_tags)
+
+ if not module.check_mode:
+ if tags_to_delete:
+ try:
+ client.delete_tags(Resources=[group_id], Tags=[{'Key': tag} for tag in tags_to_delete])
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Unable to delete tags {0}".format(tags_to_delete))
+
+ # Add/update tags
+ if tags_need_modify:
+ try:
+ client.create_tags(Resources=[group_id], Tags=ansible_dict_to_boto3_tag_list(tags_need_modify))
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json(e, msg="Unable to add tags {0}".format(tags_need_modify))
+
+ return bool(tags_need_modify or tags_to_delete)
+
+
+def update_rule_descriptions(module, group_id, present_ingress, named_tuple_ingress_list, present_egress, named_tuple_egress_list):
+ changed = False
+ client = module.client('ec2')
+ ingress_needs_desc_update = []
+ egress_needs_desc_update = []
+
+ for present_rule in present_egress:
+ needs_update = [r for r in named_tuple_egress_list if rule_cmp(r, present_rule) and r.description != present_rule.description]
+ for r in needs_update:
+ named_tuple_egress_list.remove(r)
+ egress_needs_desc_update.extend(needs_update)
+ for present_rule in present_ingress:
+ needs_update = [r for r in named_tuple_ingress_list if rule_cmp(r, present_rule) and r.description != present_rule.description]
+ for r in needs_update:
+ named_tuple_ingress_list.remove(r)
+ ingress_needs_desc_update.extend(needs_update)
+
+ if ingress_needs_desc_update:
+ update_rules_description(module, client, 'in', group_id, rules_to_permissions(ingress_needs_desc_update))
+ changed |= True
+ if egress_needs_desc_update:
+ update_rules_description(module, client, 'out', group_id, rules_to_permissions(egress_needs_desc_update))
+ changed |= True
+ return changed
+
+
+def create_security_group(client, module, name, description, vpc_id):
+ if not module.check_mode:
+ params = dict(GroupName=name, Description=description)
+ if vpc_id:
+ params['VpcId'] = vpc_id
+ try:
+ group = client.create_security_group(**params)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Unable to create security group")
+ # When a group is created, an egress_rule ALLOW ALL
+ # to 0.0.0.0/0 is added automatically but it's not
+ # reflected in the object returned by the AWS API
+ # call. We re-read the group for getting an updated object
+ # amazon sometimes takes a couple seconds to update the security group so wait till it exists
+ while True:
+ sleep(3)
+ group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0]
+ if group.get('VpcId') and not group.get('IpPermissionsEgress'):
+ pass
+ else:
+ break
+ return group
+ return None
+
+
+def wait_for_rule_propagation(module, group, desired_ingress, desired_egress, purge_ingress, purge_egress):
+ group_id = group['GroupId']
+ tries = 6
+
+ def await_rules(group, desired_rules, purge, rule_key):
+ for i in range(tries):
+ current_rules = set(sum([list(rule_from_group_permission(p)) for p in group[rule_key]], []))
+ if purge and len(current_rules ^ set(desired_rules)) == 0:
+ return group
+ elif purge:
+ conflicts = current_rules ^ set(desired_rules)
+ # For cases where set comparison is equivalent, but invalid port/proto exist
+ for a, b in itertools.combinations(conflicts, 2):
+ if rule_cmp(a, b):
+ conflicts.discard(a)
+ conflicts.discard(b)
+ if not len(conflicts):
+ return group
+ elif current_rules.issuperset(desired_rules) and not purge:
+ return group
+ sleep(10)
+ group = get_security_groups_with_backoff(module.client('ec2'), GroupIds=[group_id])['SecurityGroups'][0]
+ module.warn("Ran out of time waiting for {0} {1}. Current: {2}, Desired: {3}".format(group_id, rule_key, current_rules, desired_rules))
+ return group
+
+ group = get_security_groups_with_backoff(module.client('ec2'), GroupIds=[group_id])['SecurityGroups'][0]
+ if 'VpcId' in group and module.params.get('rules_egress') is not None:
+ group = await_rules(group, desired_egress, purge_egress, 'IpPermissionsEgress')
+ return await_rules(group, desired_ingress, purge_ingress, 'IpPermissions')
+
+
+def group_exists(client, module, vpc_id, group_id, name):
+ params = {'Filters': []}
+ if group_id:
+ params['GroupIds'] = [group_id]
+ if name:
+ # Add name to filters rather than params['GroupNames']
+ # because params['GroupNames'] only checks the default vpc if no vpc is provided
+ params['Filters'].append({'Name': 'group-name', 'Values': [name]})
+ if vpc_id:
+ params['Filters'].append({'Name': 'vpc-id', 'Values': [vpc_id]})
+ # Don't filter by description to maintain backwards compatibility
+
+ try:
+ security_groups = sg_exists_with_backoff(client, **params).get('SecurityGroups', [])
+ all_groups = get_security_groups_with_backoff(client).get('SecurityGroups', [])
+ except (BotoCoreError, ClientError) as e: # pylint: disable=duplicate-except
+ module.fail_json_aws(e, msg="Error in describe_security_groups")
+
+ if security_groups:
+ groups = dict((group['GroupId'], group) for group in all_groups)
+ groups.update(dict((group['GroupName'], group) for group in all_groups))
+ if vpc_id:
+ vpc_wins = dict((group['GroupName'], group) for group in all_groups if group.get('VpcId') and group['VpcId'] == vpc_id)
+ groups.update(vpc_wins)
+ # maintain backwards compatibility by using the last matching group
+ return security_groups[-1], groups
+ return None, {}
+
+
+def verify_rules_with_descriptions_permitted(client, module, rules, rules_egress):
+ if not hasattr(client, "update_security_group_rule_descriptions_egress"):
+ all_rules = (rules if rules else []) + (rules_egress if rules_egress else [])
+ if any('rule_desc' in rule for rule in all_rules):
+ module.fail_json(msg="Using rule descriptions requires botocore version >= 1.7.2.")
+
+
+def get_diff_final_resource(client, module, security_group):
+ def get_account_id(security_group, module):
+ try:
+ owner_id = security_group.get('owner_id', module.client('sts').get_caller_identity()['Account'])
+ except (BotoCoreError, ClientError) as e:
+ owner_id = "Unable to determine owner_id: {0}".format(to_text(e))
+ return owner_id
+
+ def get_final_tags(security_group_tags, specified_tags, purge_tags):
+ if specified_tags is None:
+ return security_group_tags
+ tags_need_modify, tags_to_delete = compare_aws_tags(security_group_tags, specified_tags, purge_tags)
+ end_result_tags = dict((k, v) for k, v in specified_tags.items() if k not in tags_to_delete)
+ end_result_tags.update(dict((k, v) for k, v in security_group_tags.items() if k not in tags_to_delete))
+ end_result_tags.update(tags_need_modify)
+ return end_result_tags
+
+ def get_final_rules(client, module, security_group_rules, specified_rules, purge_rules):
+ if specified_rules is None:
+ return security_group_rules
+ if purge_rules:
+ final_rules = []
+ else:
+ final_rules = list(security_group_rules)
+ specified_rules = flatten_nested_targets(module, deepcopy(specified_rules))
+ for rule in specified_rules:
+ format_rule = {
+ 'from_port': None, 'to_port': None, 'ip_protocol': rule.get('proto', 'tcp'),
+ 'ip_ranges': [], 'ipv6_ranges': [], 'prefix_list_ids': [], 'user_id_group_pairs': []
+ }
+ if rule.get('proto', 'tcp') in ('all', '-1', -1):
+ format_rule['ip_protocol'] = '-1'
+ format_rule.pop('from_port')
+ format_rule.pop('to_port')
+ elif rule.get('ports'):
+                if isinstance(rule['ports'], string_types) or isinstance(rule['ports'], int):
+ rule['ports'] = [rule['ports']]
+ for port in rule.get('ports'):
+ if isinstance(port, string_types) and '-' in port:
+ format_rule['from_port'], format_rule['to_port'] = port.split('-')
+ else:
+ format_rule['from_port'] = format_rule['to_port'] = port
+ elif rule.get('from_port') or rule.get('to_port'):
+ format_rule['from_port'] = rule.get('from_port', rule.get('to_port'))
+ format_rule['to_port'] = rule.get('to_port', rule.get('from_port'))
+ for source_type in ('cidr_ip', 'cidr_ipv6', 'prefix_list_id'):
+ if rule.get(source_type):
+ rule_key = {'cidr_ip': 'ip_ranges', 'cidr_ipv6': 'ipv6_ranges', 'prefix_list_id': 'prefix_list_ids'}.get(source_type)
+ if rule.get('rule_desc'):
+ format_rule[rule_key] = [{source_type: rule[source_type], 'description': rule['rule_desc']}]
+ else:
+ if not isinstance(rule[source_type], list):
+ rule[source_type] = [rule[source_type]]
+ format_rule[rule_key] = [{source_type: target} for target in rule[source_type]]
+ if rule.get('group_id') or rule.get('group_name'):
+ rule_sg = camel_dict_to_snake_dict(group_exists(client, module, module.params['vpc_id'], rule.get('group_id'), rule.get('group_name'))[0])
+ format_rule['user_id_group_pairs'] = [{
+ 'description': rule_sg.get('description', rule_sg.get('group_desc')),
+ 'group_id': rule_sg.get('group_id', rule.get('group_id')),
+ 'group_name': rule_sg.get('group_name', rule.get('group_name')),
+ 'peering_status': rule_sg.get('peering_status'),
+ 'user_id': rule_sg.get('user_id', get_account_id(security_group, module)),
+ 'vpc_id': rule_sg.get('vpc_id', module.params['vpc_id']),
+ 'vpc_peering_connection_id': rule_sg.get('vpc_peering_connection_id')
+ }]
+ for k, v in list(format_rule['user_id_group_pairs'][0].items()):
+ if v is None:
+ format_rule['user_id_group_pairs'][0].pop(k)
+ final_rules.append(format_rule)
+ # Order final rules consistently
+ final_rules.sort(key=get_ip_permissions_sort_key)
+ return final_rules
+ security_group_ingress = security_group.get('ip_permissions', [])
+ specified_ingress = module.params['rules']
+ purge_ingress = module.params['purge_rules']
+ security_group_egress = security_group.get('ip_permissions_egress', [])
+ specified_egress = module.params['rules_egress']
+ purge_egress = module.params['purge_rules_egress']
+ return {
+ 'description': module.params['description'],
+ 'group_id': security_group.get('group_id', 'sg-xxxxxxxx'),
+ 'group_name': security_group.get('group_name', module.params['name']),
+ 'ip_permissions': get_final_rules(client, module, security_group_ingress, specified_ingress, purge_ingress),
+ 'ip_permissions_egress': get_final_rules(client, module, security_group_egress, specified_egress, purge_egress),
+ 'owner_id': get_account_id(security_group, module),
+ 'tags': get_final_tags(security_group.get('tags', {}), module.params['tags'], module.params['purge_tags']),
+ 'vpc_id': security_group.get('vpc_id', module.params['vpc_id'])}
+
+
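+# flatten_nested_targets normalizes nested CIDR lists in place. For example
+# (hypothetical input), {'cidr_ip': [['10.0.0.0/24', '10.0.1.0/24'], '10.0.2.0/24']}
+# becomes {'cidr_ip': ['10.0.0.0/24', '10.0.1.0/24', '10.0.2.0/24']}.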
+def flatten_nested_targets(module, rules):
+ def _flatten(targets):
+ for target in targets:
+ if isinstance(target, list):
+ for t in _flatten(target):
+ yield t
+ elif isinstance(target, string_types):
+ yield target
+
+ if rules is not None:
+ for rule in rules:
+ target_list_type = None
+ if isinstance(rule.get('cidr_ip'), list):
+ target_list_type = 'cidr_ip'
+ elif isinstance(rule.get('cidr_ipv6'), list):
+ target_list_type = 'cidr_ipv6'
+ if target_list_type is not None:
+ rule[target_list_type] = list(_flatten(rule[target_list_type]))
+ return rules
+
+
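+# The two sort-key helpers below impose a deterministic order on rules (by CIDR,
+# prefix list id, or group id) so that before/after diff output compares stably.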
+def get_rule_sort_key(dicts):
+ if dicts.get('cidr_ip'):
+ return dicts.get('cidr_ip')
+ elif dicts.get('cidr_ipv6'):
+ return dicts.get('cidr_ipv6')
+ elif dicts.get('prefix_list_id'):
+ return dicts.get('prefix_list_id')
+ elif dicts.get('group_id'):
+ return dicts.get('group_id')
+ return None
+
+
+def get_ip_permissions_sort_key(rule):
+ if rule.get('ip_ranges'):
+ rule.get('ip_ranges').sort(key=get_rule_sort_key)
+ return rule.get('ip_ranges')[0]['cidr_ip']
+ elif rule.get('ipv6_ranges'):
+ rule.get('ipv6_ranges').sort(key=get_rule_sort_key)
+ return rule.get('ipv6_ranges')[0]['cidr_ipv6']
+ elif rule.get('prefix_list_ids'):
+ rule.get('prefix_list_ids').sort(key=get_rule_sort_key)
+ return rule.get('prefix_list_ids')[0]['prefix_list_id']
+ elif rule.get('user_id_group_pairs'):
+ rule.get('user_id_group_pairs').sort(key=get_rule_sort_key)
+ return rule.get('user_id_group_pairs')[0]['group_id']
+ return None
+
+
+def main():
+ argument_spec = dict(
+ name=dict(),
+ group_id=dict(),
+ description=dict(),
+ vpc_id=dict(),
+ rules=dict(type='list'),
+ rules_egress=dict(type='list'),
+ state=dict(default='present', type='str', choices=['present', 'absent']),
+ purge_rules=dict(default=True, required=False, type='bool'),
+ purge_rules_egress=dict(default=True, required=False, type='bool'),
+ tags=dict(required=False, type='dict', aliases=['resource_tags']),
+ purge_tags=dict(default=True, required=False, type='bool')
+ )
+ module = AnsibleAWSModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True,
+ required_one_of=[['name', 'group_id']],
+ required_if=[['state', 'present', ['name']]],
+ )
+
+ name = module.params['name']
+ group_id = module.params['group_id']
+ description = module.params['description']
+ vpc_id = module.params['vpc_id']
+ rules = flatten_nested_targets(module, deepcopy(module.params['rules']))
+ rules_egress = flatten_nested_targets(module, deepcopy(module.params['rules_egress']))
+ rules = deduplicate_rules_args(rules_expand_sources(rules_expand_ports(rules)))
+ rules_egress = deduplicate_rules_args(rules_expand_sources(rules_expand_ports(rules_egress)))
+ state = module.params.get('state')
+ purge_rules = module.params['purge_rules']
+ purge_rules_egress = module.params['purge_rules_egress']
+ tags = module.params['tags']
+ purge_tags = module.params['purge_tags']
+
+ if state == 'present' and not description:
+ module.fail_json(msg='Must provide description when state is present.')
+
+ changed = False
+ client = module.client('ec2')
+
+ verify_rules_with_descriptions_permitted(client, module, rules, rules_egress)
+ group, groups = group_exists(client, module, vpc_id, group_id, name)
+ group_created_new = not bool(group)
+
+ global current_account_id
+ current_account_id = get_aws_account_id(module)
+
+ before = {}
+ after = {}
+
+ # Ensure requested group is absent
+ if state == 'absent':
+ if group:
+ # found a match, delete it
+ before = camel_dict_to_snake_dict(group, ignore_list=['Tags'])
+ before['tags'] = boto3_tag_list_to_ansible_dict(before.get('tags', []))
+ try:
+ if not module.check_mode:
+ client.delete_security_group(GroupId=group['GroupId'])
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Unable to delete security group '%s'" % group)
+ else:
+ group = None
+ changed = True
+ else:
+ # no match found, no changes required
+ pass
+
+ # Ensure requested group is present
+ elif state == 'present':
+ if group:
+ # existing group
+ before = camel_dict_to_snake_dict(group, ignore_list=['Tags'])
+ before['tags'] = boto3_tag_list_to_ansible_dict(before.get('tags', []))
+ if group['Description'] != description:
+ module.warn("Group description does not match existing group. Descriptions cannot be changed without deleting "
+ "and re-creating the security group. Try using state=absent to delete, then rerunning this task.")
+ else:
+ # no match found, create it
+ group = create_security_group(client, module, name, description, vpc_id)
+ changed = True
+
+ if tags is not None and group is not None:
+ current_tags = boto3_tag_list_to_ansible_dict(group.get('Tags', []))
+ changed |= update_tags(client, module, group['GroupId'], current_tags, tags, purge_tags)
+
+ if group:
+ named_tuple_ingress_list = []
+ named_tuple_egress_list = []
+ current_ingress = sum([list(rule_from_group_permission(p)) for p in group['IpPermissions']], [])
+ current_egress = sum([list(rule_from_group_permission(p)) for p in group['IpPermissionsEgress']], [])
+
+ for new_rules, rule_type, named_tuple_rule_list in [(rules, 'in', named_tuple_ingress_list),
+ (rules_egress, 'out', named_tuple_egress_list)]:
+ if new_rules is None:
+ continue
+ for rule in new_rules:
+ target_type, target, target_group_created = get_target_from_rule(
+ module, client, rule, name, group, groups, vpc_id)
+ changed |= target_group_created
+
+ if rule.get('proto', 'tcp') in ('all', '-1', -1):
+ rule['proto'] = '-1'
+ rule['from_port'] = None
+ rule['to_port'] = None
+ try:
+ int(rule.get('proto', 'tcp'))
+ rule['proto'] = to_text(rule.get('proto', 'tcp'))
+ rule['from_port'] = None
+ rule['to_port'] = None
+ except ValueError:
+ # rule does not use numeric protocol spec
+ pass
+
+ named_tuple_rule_list.append(
+ Rule(
+ port_range=(rule['from_port'], rule['to_port']),
+ protocol=to_text(rule.get('proto', 'tcp')),
+ target=target, target_type=target_type,
+ description=rule.get('rule_desc'),
+ )
+ )
+
+ # List comprehensions for rules to add, rules to modify, and rule ids to determine purging
+ new_ingress_permissions = [to_permission(r) for r in (set(named_tuple_ingress_list) - set(current_ingress))]
+ new_egress_permissions = [to_permission(r) for r in (set(named_tuple_egress_list) - set(current_egress))]
+
+ if module.params.get('rules_egress') is None and 'VpcId' in group:
+ # when no egress rules are specified and we're in a VPC,
+ # we add in a default allow all out rule, which was the
+ # default behavior before egress rules were added
+ rule = Rule((None, None), '-1', '0.0.0.0/0', 'ipv4', None)
+            if rule in current_egress:
+                named_tuple_egress_list.append(rule)
+            else:
+                current_egress.append(rule)
+
+ # List comprehensions for rules to add, rules to modify, and rule ids to determine purging
+ present_ingress = list(set(named_tuple_ingress_list).union(set(current_ingress)))
+ present_egress = list(set(named_tuple_egress_list).union(set(current_egress)))
+
+ if purge_rules:
+ revoke_ingress = []
+ for p in present_ingress:
+ if not any([rule_cmp(p, b) for b in named_tuple_ingress_list]):
+ revoke_ingress.append(to_permission(p))
+ else:
+ revoke_ingress = []
+ if purge_rules_egress and module.params.get('rules_egress') is not None:
+            if module.params.get('rules_egress') == []:
+ revoke_egress = [
+ to_permission(r) for r in set(present_egress) - set(named_tuple_egress_list)
+ if r != Rule((None, None), '-1', '0.0.0.0/0', 'ipv4', None)
+ ]
+ else:
+ revoke_egress = []
+ for p in present_egress:
+ if not any([rule_cmp(p, b) for b in named_tuple_egress_list]):
+ revoke_egress.append(to_permission(p))
+ else:
+ revoke_egress = []
+
+ # named_tuple_ingress_list and named_tuple_egress_list got updated by
+ # method update_rule_descriptions, deep copy these two lists to new
+ # variables for the record of the 'desired' ingress and egress sg permissions
+ desired_ingress = deepcopy(named_tuple_ingress_list)
+ desired_egress = deepcopy(named_tuple_egress_list)
+
+ changed |= update_rule_descriptions(module, group['GroupId'], present_ingress, named_tuple_ingress_list, present_egress, named_tuple_egress_list)
+
+ # Revoke old rules
+ changed |= remove_old_permissions(client, module, revoke_ingress, revoke_egress, group['GroupId'])
+
+        new_ingress_permissions = rules_to_permissions(set(named_tuple_ingress_list) - set(current_ingress))
+ new_egress_permissions = rules_to_permissions(set(named_tuple_egress_list) - set(current_egress))
+ # Authorize new rules
+ changed |= add_new_permissions(client, module, new_ingress_permissions, new_egress_permissions, group['GroupId'])
+
+ if group_created_new and module.params.get('rules') is None and module.params.get('rules_egress') is None:
+ # A new group with no rules provided is already being awaited.
+ # When it is created we wait for the default egress rule to be added by AWS
+ security_group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0]
+ elif changed and not module.check_mode:
+            # keep polling until the current security group rules match the desired ingress and egress rules
+ security_group = wait_for_rule_propagation(module, group, desired_ingress, desired_egress, purge_rules, purge_rules_egress)
+ else:
+ security_group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0]
+ security_group = camel_dict_to_snake_dict(security_group, ignore_list=['Tags'])
+ security_group['tags'] = boto3_tag_list_to_ansible_dict(security_group.get('tags', []))
+
+ else:
+ security_group = {'group_id': None}
+
+ if module._diff:
+ if module.params['state'] == 'present':
+ after = get_diff_final_resource(client, module, security_group)
+ if before.get('ip_permissions'):
+ before['ip_permissions'].sort(key=get_ip_permissions_sort_key)
+
+ security_group['diff'] = [{'before': before, 'after': after}]
+
+ module.exit_json(changed=changed, **security_group)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/ec2_vpc_net.py b/test/support/integration/plugins/modules/ec2_vpc_net.py
new file mode 100644
index 00000000..30e4b1e9
--- /dev/null
+++ b/test/support/integration/plugins/modules/ec2_vpc_net.py
@@ -0,0 +1,524 @@
+#!/usr/bin/python
+# Copyright: Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+
+DOCUMENTATION = '''
+---
+module: ec2_vpc_net
+short_description: Configure AWS virtual private clouds
+description:
+ - Create, modify, and terminate AWS virtual private clouds.
+version_added: "2.0"
+author:
+ - Jonathan Davila (@defionscode)
+ - Sloane Hertel (@s-hertel)
+options:
+ name:
+ description:
+ - The name to give your VPC. This is used in combination with C(cidr_block) to determine if a VPC already exists.
+ required: yes
+ type: str
+ cidr_block:
+ description:
+ - The primary CIDR of the VPC. After 2.5 a list of CIDRs can be provided. The first in the list will be used as the primary CIDR
+ and is used in conjunction with the C(name) to ensure idempotence.
+ required: yes
+ type: list
+ elements: str
+ ipv6_cidr:
+ description:
+    - Request an Amazon-provided IPv6 CIDR block with /56 prefix length. You cannot specify the range of IPv6 addresses
+      or the size of the CIDR block.
+ default: False
+ type: bool
+ version_added: '2.10'
+ purge_cidrs:
+ description:
+ - Remove CIDRs that are associated with the VPC and are not specified in C(cidr_block).
+ default: no
+ type: bool
+ version_added: '2.5'
+ tenancy:
+ description:
+ - Whether to be default or dedicated tenancy. This cannot be changed after the VPC has been created.
+ default: default
+ choices: [ 'default', 'dedicated' ]
+ type: str
+ dns_support:
+ description:
+ - Whether to enable AWS DNS support.
+ default: yes
+ type: bool
+ dns_hostnames:
+ description:
+ - Whether to enable AWS hostname support.
+ default: yes
+ type: bool
+ dhcp_opts_id:
+ description:
+ - The id of the DHCP options to use for this VPC.
+ type: str
+ tags:
+ description:
+    - The tags you want attached to the VPC. This is independent of the name value; note that if you pass a 'Name' key it will
+      override the Name of the VPC if it's different.
+ aliases: [ 'resource_tags' ]
+ type: dict
+ state:
+ description:
+ - The state of the VPC. Either absent or present.
+ default: present
+ choices: [ 'present', 'absent' ]
+ type: str
+ multi_ok:
+ description:
+ - By default the module will not create another VPC if there is another VPC with the same name and CIDR block. Specify this as true if you want
+ duplicate VPCs created.
+ type: bool
+ default: false
+requirements:
+ - boto3
+ - botocore
+extends_documentation_fragment:
+ - aws
+ - ec2
+'''
+
+EXAMPLES = '''
+# Note: These examples do not set authentication details, see the AWS Guide for details.
+
+- name: create a VPC with dedicated tenancy and a couple of tags
+ ec2_vpc_net:
+ name: Module_dev2
+ cidr_block: 10.10.0.0/16
+ region: us-east-1
+ tags:
+ module: ec2_vpc_net
+ this: works
+ tenancy: dedicated
+
+- name: create a VPC with dedicated tenancy and request an IPv6 CIDR
+ ec2_vpc_net:
+ name: Module_dev2
+ cidr_block: 10.10.0.0/16
+ ipv6_cidr: True
+ region: us-east-1
+ tenancy: dedicated
+'''
+
+RETURN = '''
+vpc:
+ description: info about the VPC that was created or deleted
+ returned: always
+ type: complex
+ contains:
+ cidr_block:
+ description: The CIDR of the VPC
+ returned: always
+ type: str
+ sample: 10.0.0.0/16
+ cidr_block_association_set:
+ description: IPv4 CIDR blocks associated with the VPC
+ returned: success
+ type: list
+ sample:
+ "cidr_block_association_set": [
+ {
+ "association_id": "vpc-cidr-assoc-97aeeefd",
+ "cidr_block": "20.0.0.0/24",
+ "cidr_block_state": {
+ "state": "associated"
+ }
+ }
+ ]
+ classic_link_enabled:
+ description: indicates whether ClassicLink is enabled
+ returned: always
+ type: bool
+ sample: false
+ dhcp_options_id:
+ description: the id of the DHCP options associated with this VPC
+ returned: always
+ type: str
+ sample: dopt-0fb8bd6b
+ id:
+ description: VPC resource id
+ returned: always
+ type: str
+ sample: vpc-c2e00da5
+ instance_tenancy:
+ description: indicates whether VPC uses default or dedicated tenancy
+ returned: always
+ type: str
+ sample: default
+ ipv6_cidr_block_association_set:
+ description: IPv6 CIDR blocks associated with the VPC
+ returned: success
+ type: list
+ sample:
+ "ipv6_cidr_block_association_set": [
+ {
+ "association_id": "vpc-cidr-assoc-97aeeefd",
+ "ipv6_cidr_block": "2001:db8::/56",
+ "ipv6_cidr_block_state": {
+ "state": "associated"
+ }
+ }
+ ]
+ is_default:
+ description: indicates whether this is the default VPC
+ returned: always
+ type: bool
+ sample: false
+ state:
+ description: state of the VPC
+ returned: always
+ type: str
+ sample: available
+ tags:
+ description: tags attached to the VPC, includes name
+ returned: always
+ type: complex
+ contains:
+ Name:
+ description: name tag for the VPC
+ returned: always
+ type: str
+ sample: pk_vpc4
+'''
+
+try:
+ import botocore
+except ImportError:
+ pass # Handled by AnsibleAWSModule
+
+from time import sleep, time
+from ansible.module_utils.aws.core import AnsibleAWSModule
+from ansible.module_utils.ec2 import (AWSRetry, camel_dict_to_snake_dict, compare_aws_tags,
+ ansible_dict_to_boto3_tag_list, boto3_tag_list_to_ansible_dict)
+from ansible.module_utils.six import string_types
+from ansible.module_utils._text import to_native
+from ansible.module_utils.network.common.utils import to_subnet
+
+
+def vpc_exists(module, vpc, name, cidr_block, multi):
+ """Returns None or a vpc object depending on the existence of a VPC. When supplied
+ with a CIDR, it will check for matching tags to determine if it is a match
+ otherwise it will assume the VPC does not exist and thus return None.
+ """
+ try:
+ matching_vpcs = vpc.describe_vpcs(Filters=[{'Name': 'tag:Name', 'Values': [name]}, {'Name': 'cidr-block', 'Values': cidr_block}])['Vpcs']
+        # If an exact match using the list of CIDRs isn't found, check for a match with the first CIDR as is documented for C(cidr_block)
+ if not matching_vpcs:
+ matching_vpcs = vpc.describe_vpcs(Filters=[{'Name': 'tag:Name', 'Values': [name]}, {'Name': 'cidr-block', 'Values': [cidr_block[0]]}])['Vpcs']
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to describe VPCs")
+
+ if multi:
+ return None
+ elif len(matching_vpcs) == 1:
+ return matching_vpcs[0]['VpcId']
+ elif len(matching_vpcs) > 1:
+ module.fail_json(msg='Currently there are %d VPCs that have the same name and '
+ 'CIDR block you specified. If you would like to create '
+ 'the VPC anyway please pass True to the multi_ok param.' % len(matching_vpcs))
+ return None
+
+
+@AWSRetry.backoff(delay=3, tries=8, catch_extra_error_codes=['InvalidVpcID.NotFound'])
+def get_classic_link_with_backoff(connection, vpc_id):
+ try:
+ return connection.describe_vpc_classic_link(VpcIds=[vpc_id])['Vpcs'][0].get('ClassicLinkEnabled')
+ except botocore.exceptions.ClientError as e:
+ if e.response["Error"]["Message"] == "The functionality you requested is not available in this region.":
+ return False
+ else:
+ raise
+
+
+def get_vpc(module, connection, vpc_id):
+ # wait for vpc to be available
+ try:
+ connection.get_waiter('vpc_available').wait(VpcIds=[vpc_id])
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Unable to wait for VPC {0} to be available.".format(vpc_id))
+
+ try:
+ vpc_obj = connection.describe_vpcs(VpcIds=[vpc_id], aws_retry=True)['Vpcs'][0]
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to describe VPCs")
+ try:
+ vpc_obj['ClassicLinkEnabled'] = get_classic_link_with_backoff(connection, vpc_id)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to describe VPCs")
+
+ return vpc_obj
+
+
+def update_vpc_tags(connection, module, vpc_id, tags, name):
+ if tags is None:
+ tags = dict()
+
+ tags.update({'Name': name})
+ tags = dict((k, to_native(v)) for k, v in tags.items())
+ try:
+ current_tags = dict((t['Key'], t['Value']) for t in connection.describe_tags(Filters=[{'Name': 'resource-id', 'Values': [vpc_id]}])['Tags'])
+ tags_to_update, dummy = compare_aws_tags(current_tags, tags, False)
+ if tags_to_update:
+ if not module.check_mode:
+ tags = ansible_dict_to_boto3_tag_list(tags_to_update)
+                connection.create_tags(Resources=[vpc_id], Tags=tags, aws_retry=True)
+
+ # Wait for tags to be updated
+ expected_tags = boto3_tag_list_to_ansible_dict(tags)
+ filters = [{'Name': 'tag:{0}'.format(key), 'Values': [value]} for key, value in expected_tags.items()]
+ connection.get_waiter('vpc_available').wait(VpcIds=[vpc_id], Filters=filters)
+
+ return True
+ else:
+ return False
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to update tags")
+
+
+def update_dhcp_opts(connection, module, vpc_obj, dhcp_id):
+ if vpc_obj['DhcpOptionsId'] != dhcp_id:
+ if not module.check_mode:
+ try:
+ connection.associate_dhcp_options(DhcpOptionsId=dhcp_id, VpcId=vpc_obj['VpcId'])
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to associate DhcpOptionsId {0}".format(dhcp_id))
+
+ try:
+ # Wait for DhcpOptionsId to be updated
+ filters = [{'Name': 'dhcp-options-id', 'Values': [dhcp_id]}]
+ connection.get_waiter('vpc_available').wait(VpcIds=[vpc_obj['VpcId']], Filters=filters)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json(msg="Failed to wait for DhcpOptionsId to be updated")
+
+ return True
+ else:
+ return False
+
+
+def create_vpc(connection, module, cidr_block, tenancy):
+ try:
+ if not module.check_mode:
+ vpc_obj = connection.create_vpc(CidrBlock=cidr_block, InstanceTenancy=tenancy)
+ else:
+ module.exit_json(changed=True)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Failed to create the VPC")
+
+ # wait for vpc to exist
+ try:
+ connection.get_waiter('vpc_exists').wait(VpcIds=[vpc_obj['Vpc']['VpcId']])
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Unable to wait for VPC {0} to be created.".format(vpc_obj['Vpc']['VpcId']))
+
+ return vpc_obj['Vpc']['VpcId']
+
+
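+# wait_for_vpc_attribute polls describe_vpc_attribute every 3 seconds for up to
+# 300 seconds until the attribute reports the expected value, failing the module
+# if it never converges.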
+def wait_for_vpc_attribute(connection, module, vpc_id, attribute, expected_value):
+ start_time = time()
+ updated = False
+ while time() < start_time + 300:
+ current_value = connection.describe_vpc_attribute(
+ Attribute=attribute,
+ VpcId=vpc_id
+ )['{0}{1}'.format(attribute[0].upper(), attribute[1:])]['Value']
+ if current_value != expected_value:
+ sleep(3)
+ else:
+ updated = True
+ break
+ if not updated:
+ module.fail_json(msg="Failed to wait for {0} to be updated".format(attribute))
+
+
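+# get_cidr_network_bits masks host bits off each CIDR before it is compared with
+# what AWS reports. As a hypothetical example, a cidr_block entry of 10.1.2.3/24
+# would be normalized (with a warning) to the to_subnet() result, expected to be
+# 10.1.2.0/24.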
+def get_cidr_network_bits(module, cidr_block):
+ fixed_cidrs = []
+ for cidr in cidr_block:
+ split_addr = cidr.split('/')
+ if len(split_addr) == 2:
+            # This is an IPv4 CIDR that may or may not have host bits set.
+ # Get the network bits.
+ valid_cidr = to_subnet(split_addr[0], split_addr[1])
+ if cidr != valid_cidr:
+ module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, "
+ "check the network mask and make sure that only network bits are set: {1}.".format(cidr, valid_cidr))
+ fixed_cidrs.append(valid_cidr)
+ else:
+ # let AWS handle invalid CIDRs
+ fixed_cidrs.append(cidr)
+ return fixed_cidrs
+
+
+def main():
+ argument_spec = dict(
+ name=dict(required=True),
+ cidr_block=dict(type='list', required=True),
+ ipv6_cidr=dict(type='bool', default=False),
+ tenancy=dict(choices=['default', 'dedicated'], default='default'),
+ dns_support=dict(type='bool', default=True),
+ dns_hostnames=dict(type='bool', default=True),
+ dhcp_opts_id=dict(),
+ tags=dict(type='dict', aliases=['resource_tags']),
+ state=dict(choices=['present', 'absent'], default='present'),
+ multi_ok=dict(type='bool', default=False),
+ purge_cidrs=dict(type='bool', default=False),
+ )
+
+ module = AnsibleAWSModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True
+ )
+
+ name = module.params.get('name')
+ cidr_block = get_cidr_network_bits(module, module.params.get('cidr_block'))
+ ipv6_cidr = module.params.get('ipv6_cidr')
+ purge_cidrs = module.params.get('purge_cidrs')
+ tenancy = module.params.get('tenancy')
+ dns_support = module.params.get('dns_support')
+ dns_hostnames = module.params.get('dns_hostnames')
+ dhcp_id = module.params.get('dhcp_opts_id')
+ tags = module.params.get('tags')
+ state = module.params.get('state')
+ multi = module.params.get('multi_ok')
+
+ changed = False
+
+ connection = module.client(
+ 'ec2',
+ retry_decorator=AWSRetry.jittered_backoff(
+ retries=8, delay=3, catch_extra_error_codes=['InvalidVpcID.NotFound']
+ )
+ )
+
+ if dns_hostnames and not dns_support:
+ module.fail_json(msg='In order to enable DNS Hostnames you must also enable DNS support')
+
+ if state == 'present':
+
+ # Check if VPC exists
+ vpc_id = vpc_exists(module, connection, name, cidr_block, multi)
+
+ if vpc_id is None:
+ vpc_id = create_vpc(connection, module, cidr_block[0], tenancy)
+ changed = True
+
+ vpc_obj = get_vpc(module, connection, vpc_id)
+
+ associated_cidrs = dict((cidr['CidrBlock'], cidr['AssociationId']) for cidr in vpc_obj.get('CidrBlockAssociationSet', [])
+ if cidr['CidrBlockState']['State'] != 'disassociated')
+ to_add = [cidr for cidr in cidr_block if cidr not in associated_cidrs]
+ to_remove = [associated_cidrs[cidr] for cidr in associated_cidrs if cidr not in cidr_block]
+ expected_cidrs = [cidr for cidr in associated_cidrs if associated_cidrs[cidr] not in to_remove] + to_add
+
+ if len(cidr_block) > 1:
+ for cidr in to_add:
+ changed = True
+ try:
+ connection.associate_vpc_cidr_block(CidrBlock=cidr, VpcId=vpc_id)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Unable to associate CIDR {0}.".format(ipv6_cidr))
+ if ipv6_cidr:
+            if 'Ipv6CidrBlockAssociationSet' in vpc_obj:
+ module.warn("Only one IPv6 CIDR is permitted per VPC, {0} already has CIDR {1}".format(
+ vpc_id,
+ vpc_obj['Ipv6CidrBlockAssociationSet'][0]['Ipv6CidrBlock']))
+ else:
+ try:
+ connection.associate_vpc_cidr_block(AmazonProvidedIpv6CidrBlock=ipv6_cidr, VpcId=vpc_id)
+ changed = True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Unable to associate CIDR {0}.".format(ipv6_cidr))
+
+ if purge_cidrs:
+ for association_id in to_remove:
+ changed = True
+ try:
+ connection.disassociate_vpc_cidr_block(AssociationId=association_id)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Unable to disassociate {0}. You must detach or delete all gateways and resources that "
+ "are associated with the CIDR block before you can disassociate it.".format(association_id))
+
+ if dhcp_id is not None:
+ try:
+ if update_dhcp_opts(connection, module, vpc_obj, dhcp_id):
+ changed = True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Failed to update DHCP options")
+
+ if tags is not None or name is not None:
+ try:
+ if update_vpc_tags(connection, module, vpc_id, tags, name):
+ changed = True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to update tags")
+
+ current_dns_enabled = connection.describe_vpc_attribute(Attribute='enableDnsSupport', VpcId=vpc_id, aws_retry=True)['EnableDnsSupport']['Value']
+ current_dns_hostnames = connection.describe_vpc_attribute(Attribute='enableDnsHostnames', VpcId=vpc_id, aws_retry=True)['EnableDnsHostnames']['Value']
+ if current_dns_enabled != dns_support:
+ changed = True
+ if not module.check_mode:
+ try:
+ connection.modify_vpc_attribute(VpcId=vpc_id, EnableDnsSupport={'Value': dns_support})
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Failed to update enabled dns support attribute")
+ if current_dns_hostnames != dns_hostnames:
+ changed = True
+ if not module.check_mode:
+ try:
+ connection.modify_vpc_attribute(VpcId=vpc_id, EnableDnsHostnames={'Value': dns_hostnames})
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Failed to update enabled dns hostnames attribute")
+
+ # wait for associated cidrs to match
+ if to_add or to_remove:
+ try:
+ connection.get_waiter('vpc_available').wait(
+ VpcIds=[vpc_id],
+ Filters=[{'Name': 'cidr-block-association.cidr-block', 'Values': expected_cidrs}]
+ )
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Failed to wait for CIDRs to update")
+
+ # try to wait for enableDnsSupport and enableDnsHostnames to match
+ wait_for_vpc_attribute(connection, module, vpc_id, 'enableDnsSupport', dns_support)
+ wait_for_vpc_attribute(connection, module, vpc_id, 'enableDnsHostnames', dns_hostnames)
+
+ final_state = camel_dict_to_snake_dict(get_vpc(module, connection, vpc_id))
+ final_state['tags'] = boto3_tag_list_to_ansible_dict(final_state.get('tags', []))
+ final_state['id'] = final_state.pop('vpc_id')
+
+ module.exit_json(changed=changed, vpc=final_state)
+
+ elif state == 'absent':
+
+ # Check if VPC exists
+ vpc_id = vpc_exists(module, connection, name, cidr_block, multi)
+
+ if vpc_id is not None:
+ try:
+ if not module.check_mode:
+ connection.delete_vpc(VpcId=vpc_id)
+ changed = True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to delete VPC {0} You may want to use the ec2_vpc_subnet, ec2_vpc_igw, "
+ "and/or ec2_vpc_route_table modules to ensure the other components are absent.".format(vpc_id))
+
+ module.exit_json(changed=changed, vpc={})
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/ec2_vpc_subnet.py b/test/support/integration/plugins/modules/ec2_vpc_subnet.py
new file mode 100644
index 00000000..5085e99b
--- /dev/null
+++ b/test/support/integration/plugins/modules/ec2_vpc_subnet.py
@@ -0,0 +1,604 @@
+#!/usr/bin/python
+# Copyright: Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+
+DOCUMENTATION = '''
+---
+module: ec2_vpc_subnet
+short_description: Manage subnets in AWS virtual private clouds
+description:
+ - Manage subnets in AWS virtual private clouds.
+version_added: "2.0"
+author:
+- Robert Estelle (@erydo)
+- Brad Davidson (@brandond)
+requirements: [ boto3 ]
+options:
+ az:
+ description:
+ - "The availability zone for the subnet."
+ type: str
+ cidr:
+ description:
+ - "The CIDR block for the subnet. E.g. 192.0.2.0/24."
+ type: str
+ required: true
+ ipv6_cidr:
+ description:
+ - "The IPv6 CIDR block for the subnet. The VPC must have a /56 block assigned and this value must be a valid IPv6 /64 that falls in the VPC range."
+ - "Required if I(assign_instances_ipv6=true)"
+ version_added: "2.5"
+ type: str
+ tags:
+ description:
+ - "A dict of tags to apply to the subnet. Any tags currently applied to the subnet and not present here will be removed."
+ aliases: [ 'resource_tags' ]
+ type: dict
+ state:
+ description:
+ - "Create or remove the subnet."
+ default: present
+ choices: [ 'present', 'absent' ]
+ type: str
+ vpc_id:
+ description:
+ - "VPC ID of the VPC in which to create or delete the subnet."
+ required: true
+ type: str
+ map_public:
+ description:
+ - "Specify C(yes) to indicate that instances launched into the subnet should be assigned public IP address by default."
+ type: bool
+ default: 'no'
+ version_added: "2.4"
+ assign_instances_ipv6:
+ description:
+ - "Specify C(yes) to indicate that instances launched into the subnet should be automatically assigned an IPv6 address."
+ type: bool
+ default: false
+ version_added: "2.5"
+ wait:
+ description:
+ - "When I(wait=true) and I(state=present), module will wait for subnet to be in available state before continuing."
+ type: bool
+ default: true
+ version_added: "2.5"
+ wait_timeout:
+ description:
+ - "Number of seconds to wait for subnet to become available I(wait=True)."
+ default: 300
+ version_added: "2.5"
+ type: int
+ purge_tags:
+ description:
+ - Whether or not to remove tags that do not appear in the I(tags) list.
+ type: bool
+ default: true
+ version_added: "2.5"
+extends_documentation_fragment:
+ - aws
+ - ec2
+'''
+
+EXAMPLES = '''
+# Note: These examples do not set authentication details, see the AWS Guide for details.
+
+- name: Create subnet for database servers
+ ec2_vpc_subnet:
+ state: present
+ vpc_id: vpc-123456
+ cidr: 10.0.1.16/28
+ tags:
+ Name: Database Subnet
+ register: database_subnet
+
+- name: Remove subnet for database servers
+ ec2_vpc_subnet:
+ state: absent
+ vpc_id: vpc-123456
+ cidr: 10.0.1.16/28
+
+- name: Create subnet with IPv6 block assigned
+ ec2_vpc_subnet:
+ state: present
+ vpc_id: vpc-123456
+ cidr: 10.1.100.0/24
+ ipv6_cidr: 2001:db8:0:102::/64
+
+- name: Remove IPv6 block assigned to subnet
+ ec2_vpc_subnet:
+ state: present
+ vpc_id: vpc-123456
+ cidr: 10.1.100.0/24
+ ipv6_cidr: ''
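+
+# A sketch (not from the original examples) using the documented az and
+# map_public options:
+- name: Create a public subnet in a specific availability zone
+  ec2_vpc_subnet:
+    state: present
+    vpc_id: vpc-123456
+    cidr: 10.0.2.0/24
+    az: us-east-1a
+    map_public: yes
+    tags:
+      Name: Public Subnet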
+'''
+
+RETURN = '''
+subnet:
+ description: Dictionary of subnet values
+ returned: I(state=present)
+ type: complex
+ contains:
+ id:
+ description: Subnet resource id
+ returned: I(state=present)
+ type: str
+ sample: subnet-b883b2c4
+ cidr_block:
+ description: The IPv4 CIDR of the Subnet
+ returned: I(state=present)
+ type: str
+ sample: "10.0.0.0/16"
+ ipv6_cidr_block:
+ description: The IPv6 CIDR block actively associated with the Subnet
+ returned: I(state=present)
+ type: str
+ sample: "2001:db8:0:102::/64"
+ availability_zone:
+ description: Availability zone of the Subnet
+ returned: I(state=present)
+ type: str
+ sample: us-east-1a
+ state:
+ description: state of the Subnet
+ returned: I(state=present)
+ type: str
+ sample: available
+ tags:
+ description: tags attached to the Subnet, includes name
+ returned: I(state=present)
+ type: dict
+ sample: {"Name": "My Subnet", "env": "staging"}
+ map_public_ip_on_launch:
+ description: whether public IP is auto-assigned to new instances
+ returned: I(state=present)
+ type: bool
+ sample: false
+ assign_ipv6_address_on_creation:
+ description: whether IPv6 address is auto-assigned to new instances
+ returned: I(state=present)
+ type: bool
+ sample: false
+ vpc_id:
+ description: the id of the VPC where this Subnet exists
+ returned: I(state=present)
+ type: str
+ sample: vpc-67236184
+ available_ip_address_count:
+ description: number of available IPv4 addresses
+ returned: I(state=present)
+ type: str
+ sample: 251
+ default_for_az:
+ description: indicates whether this is the default Subnet for this Availability Zone
+ returned: I(state=present)
+ type: bool
+ sample: false
+ ipv6_association_id:
+ description: The IPv6 association ID for the currently associated CIDR
+ returned: I(state=present)
+ type: str
+ sample: subnet-cidr-assoc-b85c74d2
+ ipv6_cidr_block_association_set:
+ description: An array of IPv6 cidr block association set information.
+ returned: I(state=present)
+ type: complex
+ contains:
+ association_id:
+ description: The association ID
+ returned: always
+ type: str
+ ipv6_cidr_block:
+ description: The IPv6 CIDR block that is associated with the subnet.
+ returned: always
+ type: str
+ ipv6_cidr_block_state:
+ description: A hash/dict that contains a single item. The state of the cidr block association.
+ returned: always
+ type: dict
+ contains:
+ state:
+ description: The CIDR block association state.
+ returned: always
+ type: str
+'''
+
+
+import time
+
+try:
+ import botocore
+except ImportError:
+ pass # caught by AnsibleAWSModule
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.aws.core import AnsibleAWSModule
+from ansible.module_utils.aws.waiters import get_waiter
+from ansible.module_utils.ec2 import (ansible_dict_to_boto3_filter_list, ansible_dict_to_boto3_tag_list,
+ camel_dict_to_snake_dict, boto3_tag_list_to_ansible_dict, compare_aws_tags, AWSRetry)
+
+
+def get_subnet_info(subnet):
+ if 'Subnets' in subnet:
+ return [get_subnet_info(s) for s in subnet['Subnets']]
+ elif 'Subnet' in subnet:
+ subnet = camel_dict_to_snake_dict(subnet['Subnet'])
+ else:
+ subnet = camel_dict_to_snake_dict(subnet)
+
+ if 'tags' in subnet:
+ subnet['tags'] = boto3_tag_list_to_ansible_dict(subnet['tags'])
+ else:
+ subnet['tags'] = dict()
+
+ if 'subnet_id' in subnet:
+ subnet['id'] = subnet['subnet_id']
+ del subnet['subnet_id']
+
+ subnet['ipv6_cidr_block'] = ''
+ subnet['ipv6_association_id'] = ''
+ ipv6set = subnet.get('ipv6_cidr_block_association_set')
+ if ipv6set:
+ for item in ipv6set:
+ if item.get('ipv6_cidr_block_state', {}).get('state') in ('associated', 'associating'):
+ subnet['ipv6_cidr_block'] = item['ipv6_cidr_block']
+ subnet['ipv6_association_id'] = item['association_id']
+
+ return subnet
+
+
+@AWSRetry.exponential_backoff()
+def describe_subnets_with_backoff(client, **params):
+ return client.describe_subnets(**params)
+
+
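+# waiter_params spreads the remaining wait_timeout across boto3's WaiterConfig:
+# with a fixed 5 second delay, MaxAttempts is the remaining budget divided by 5,
+# so several consecutive waits in one run stay within the overall timeout.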
+def waiter_params(module, params, start_time):
+    if module.botocore_at_least("1.7.0"):
+ remaining_wait_timeout = int(module.params['wait_timeout'] + start_time - time.time())
+ params['WaiterConfig'] = {'Delay': 5, 'MaxAttempts': remaining_wait_timeout // 5}
+ return params
+
+
+def handle_waiter(conn, module, waiter_name, params, start_time):
+ try:
+ get_waiter(conn, waiter_name).wait(
+ **waiter_params(module, params, start_time)
+ )
+ except botocore.exceptions.WaiterError as e:
+ module.fail_json_aws(e, "Failed to wait for updates to complete")
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "An exception happened while trying to wait for updates")
+
+
+def create_subnet(conn, module, vpc_id, cidr, ipv6_cidr=None, az=None, start_time=None):
+ wait = module.params['wait']
+ wait_timeout = module.params['wait_timeout']
+
+ params = dict(VpcId=vpc_id,
+ CidrBlock=cidr)
+
+ if ipv6_cidr:
+ params['Ipv6CidrBlock'] = ipv6_cidr
+
+ if az:
+ params['AvailabilityZone'] = az
+
+ try:
+ subnet = get_subnet_info(conn.create_subnet(**params))
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't create subnet")
+
+    # Sometimes AWS takes its time to create a subnet and so using the new
+    # subnet's id to do things like create tags results in an exception.
+ if wait and subnet.get('state') != 'available':
+ handle_waiter(conn, module, 'subnet_exists', {'SubnetIds': [subnet['id']]}, start_time)
+ try:
+ conn.get_waiter('subnet_available').wait(
+ **waiter_params(module, {'SubnetIds': [subnet['id']]}, start_time)
+ )
+ subnet['state'] = 'available'
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, "Create subnet action timed out waiting for subnet to become available")
+
+ return subnet
+
+
+def ensure_tags(conn, module, subnet, tags, purge_tags, start_time):
+ changed = False
+
+ filters = ansible_dict_to_boto3_filter_list({'resource-id': subnet['id'], 'resource-type': 'subnet'})
+ try:
+ cur_tags = conn.describe_tags(Filters=filters)
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't describe tags")
+
+ to_update, to_delete = compare_aws_tags(boto3_tag_list_to_ansible_dict(cur_tags.get('Tags')), tags, purge_tags)
+
+ if to_update:
+ try:
+ if not module.check_mode:
+ AWSRetry.exponential_backoff(
+ catch_extra_error_codes=['InvalidSubnetID.NotFound']
+ )(conn.create_tags)(
+ Resources=[subnet['id']],
+ Tags=ansible_dict_to_boto3_tag_list(to_update)
+ )
+
+ changed = True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't create tags")
+
+ if to_delete:
+ try:
+ if not module.check_mode:
+ tags_list = []
+ for key in to_delete:
+ tags_list.append({'Key': key})
+
+ AWSRetry.exponential_backoff(
+ catch_extra_error_codes=['InvalidSubnetID.NotFound']
+ )(conn.delete_tags)(Resources=[subnet['id']], Tags=tags_list)
+
+ changed = True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't delete tags")
+
+ if module.params['wait'] and not module.check_mode:
+ # Wait for tags to be updated
+ filters = [{'Name': 'tag:{0}'.format(k), 'Values': [v]} for k, v in tags.items()]
+ handle_waiter(conn, module, 'subnet_exists',
+ {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time)
+
+ return changed
+
+
+def ensure_map_public(conn, module, subnet, map_public, check_mode, start_time):
+ if check_mode:
+ return
+ try:
+ conn.modify_subnet_attribute(SubnetId=subnet['id'], MapPublicIpOnLaunch={'Value': map_public})
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't modify subnet attribute")
+
+
+def ensure_assign_ipv6_on_create(conn, module, subnet, assign_instances_ipv6, check_mode, start_time):
+ if check_mode:
+ return
+ try:
+ conn.modify_subnet_attribute(SubnetId=subnet['id'], AssignIpv6AddressOnCreation={'Value': assign_instances_ipv6})
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't modify subnet attribute")
+
+
+def disassociate_ipv6_cidr(conn, module, subnet, start_time):
+ if subnet.get('assign_ipv6_address_on_creation'):
+ ensure_assign_ipv6_on_create(conn, module, subnet, False, False, start_time)
+
+ try:
+ conn.disassociate_subnet_cidr_block(AssociationId=subnet['ipv6_association_id'])
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't disassociate ipv6 cidr block id {0} from subnet {1}"
+ .format(subnet['ipv6_association_id'], subnet['id']))
+
+ # Wait for cidr block to be disassociated
+ if module.params['wait']:
+ filters = ansible_dict_to_boto3_filter_list(
+ {'ipv6-cidr-block-association.state': ['disassociated'],
+ 'vpc-id': subnet['vpc_id']}
+ )
+ handle_waiter(conn, module, 'subnet_exists',
+ {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time)
+
+
+def ensure_ipv6_cidr_block(conn, module, subnet, ipv6_cidr, check_mode, start_time):
+ wait = module.params['wait']
+ changed = False
+
+ if subnet['ipv6_association_id'] and not ipv6_cidr:
+ if not check_mode:
+ disassociate_ipv6_cidr(conn, module, subnet, start_time)
+ changed = True
+
+ if ipv6_cidr:
+ filters = ansible_dict_to_boto3_filter_list({'ipv6-cidr-block-association.ipv6-cidr-block': ipv6_cidr,
+ 'vpc-id': subnet['vpc_id']})
+
+ try:
+ check_subnets = get_subnet_info(describe_subnets_with_backoff(conn, Filters=filters))
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't get subnet info")
+
+ if check_subnets and check_subnets[0]['ipv6_cidr_block']:
+ module.fail_json(msg="The IPv6 CIDR '{0}' conflicts with another subnet".format(ipv6_cidr))
+
+ if subnet['ipv6_association_id']:
+ if not check_mode:
+ disassociate_ipv6_cidr(conn, module, subnet, start_time)
+ changed = True
+
+ try:
+ if not check_mode:
+ associate_resp = conn.associate_subnet_cidr_block(SubnetId=subnet['id'], Ipv6CidrBlock=ipv6_cidr)
+ changed = True
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't associate ipv6 cidr {0} to {1}".format(ipv6_cidr, subnet['id']))
+ else:
+ if not check_mode and wait:
+ filters = ansible_dict_to_boto3_filter_list(
+ {'ipv6-cidr-block-association.state': ['associated'],
+ 'vpc-id': subnet['vpc_id']}
+ )
+ handle_waiter(conn, module, 'subnet_exists',
+ {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time)
+
+            # associate_resp is only assigned when a real association was attempted
+            if not check_mode and associate_resp.get('Ipv6CidrBlockAssociation', {}).get('AssociationId'):
+ subnet['ipv6_association_id'] = associate_resp['Ipv6CidrBlockAssociation']['AssociationId']
+ subnet['ipv6_cidr_block'] = associate_resp['Ipv6CidrBlockAssociation']['Ipv6CidrBlock']
+ if subnet['ipv6_cidr_block_association_set']:
+ subnet['ipv6_cidr_block_association_set'][0] = camel_dict_to_snake_dict(associate_resp['Ipv6CidrBlockAssociation'])
+ else:
+ subnet['ipv6_cidr_block_association_set'].append(camel_dict_to_snake_dict(associate_resp['Ipv6CidrBlockAssociation']))
+
+ return changed
+
+
+def get_matching_subnet(conn, module, vpc_id, cidr):
+ filters = ansible_dict_to_boto3_filter_list({'vpc-id': vpc_id, 'cidr-block': cidr})
+ try:
+ subnets = get_subnet_info(describe_subnets_with_backoff(conn, Filters=filters))
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't get matching subnet")
+
+ if subnets:
+ return subnets[0]
+
+ return None
+
+
+def ensure_subnet_present(conn, module):
+ subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr'])
+ changed = False
+
+ # Initialize start so max time does not exceed the specified wait_timeout for multiple operations
+ start_time = time.time()
+
+ if subnet is None:
+ if not module.check_mode:
+ subnet = create_subnet(conn, module, module.params['vpc_id'], module.params['cidr'],
+ ipv6_cidr=module.params['ipv6_cidr'], az=module.params['az'], start_time=start_time)
+ changed = True
+ # Subnet will be None when check_mode is true
+ if subnet is None:
+ return {
+ 'changed': changed,
+ 'subnet': {}
+ }
+ if module.params['wait']:
+ handle_waiter(conn, module, 'subnet_exists', {'SubnetIds': [subnet['id']]}, start_time)
+
+ if module.params['ipv6_cidr'] != subnet.get('ipv6_cidr_block'):
+ if ensure_ipv6_cidr_block(conn, module, subnet, module.params['ipv6_cidr'], module.check_mode, start_time):
+ changed = True
+
+ if module.params['map_public'] != subnet['map_public_ip_on_launch']:
+ ensure_map_public(conn, module, subnet, module.params['map_public'], module.check_mode, start_time)
+ changed = True
+
+ if module.params['assign_instances_ipv6'] != subnet.get('assign_ipv6_address_on_creation'):
+ ensure_assign_ipv6_on_create(conn, module, subnet, module.params['assign_instances_ipv6'], module.check_mode, start_time)
+ changed = True
+
+ if module.params['tags'] != subnet['tags']:
+ stringified_tags_dict = dict((to_text(k), to_text(v)) for k, v in module.params['tags'].items())
+ if ensure_tags(conn, module, subnet, stringified_tags_dict, module.params['purge_tags'], start_time):
+ changed = True
+
+ subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr'])
+ if not module.check_mode and module.params['wait']:
+ # GET calls are not monotonic for map_public_ip_on_launch and assign_ipv6_address_on_creation
+ # so we only wait for those if necessary just before returning the subnet
+ subnet = ensure_final_subnet(conn, module, subnet, start_time)
+
+ return {
+ 'changed': changed,
+ 'subnet': subnet
+ }
+
+
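+# describe_subnets can briefly return stale values for the two modified
+# attributes, so ensure_final_subnet re-reads the subnet up to 30 times, waiting
+# on the matching attribute waiter between reads, before returning it.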
+def ensure_final_subnet(conn, module, subnet, start_time):
+ for rewait in range(0, 30):
+ map_public_correct = False
+ assign_ipv6_correct = False
+
+ if module.params['map_public'] == subnet['map_public_ip_on_launch']:
+ map_public_correct = True
+ else:
+ if module.params['map_public']:
+ handle_waiter(conn, module, 'subnet_has_map_public', {'SubnetIds': [subnet['id']]}, start_time)
+ else:
+ handle_waiter(conn, module, 'subnet_no_map_public', {'SubnetIds': [subnet['id']]}, start_time)
+
+ if module.params['assign_instances_ipv6'] == subnet.get('assign_ipv6_address_on_creation'):
+ assign_ipv6_correct = True
+ else:
+ if module.params['assign_instances_ipv6']:
+ handle_waiter(conn, module, 'subnet_has_assign_ipv6', {'SubnetIds': [subnet['id']]}, start_time)
+ else:
+ handle_waiter(conn, module, 'subnet_no_assign_ipv6', {'SubnetIds': [subnet['id']]}, start_time)
+
+ if map_public_correct and assign_ipv6_correct:
+ break
+
+ time.sleep(5)
+ subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr'])
+
+ return subnet
+
+
+def ensure_subnet_absent(conn, module):
+ subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr'])
+ if subnet is None:
+ return {'changed': False}
+
+ try:
+ if not module.check_mode:
+ conn.delete_subnet(SubnetId=subnet['id'])
+ if module.params['wait']:
+ handle_waiter(conn, module, 'subnet_deleted', {'SubnetIds': [subnet['id']]}, time.time())
+ return {'changed': True}
+ except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Couldn't delete subnet")
+
+
+def main():
+ argument_spec = dict(
+ az=dict(default=None, required=False),
+ cidr=dict(required=True),
+ ipv6_cidr=dict(default='', required=False),
+ state=dict(default='present', choices=['present', 'absent']),
+ tags=dict(default={}, required=False, type='dict', aliases=['resource_tags']),
+ vpc_id=dict(required=True),
+ map_public=dict(default=False, required=False, type='bool'),
+ assign_instances_ipv6=dict(default=False, required=False, type='bool'),
+ wait=dict(type='bool', default=True),
+ wait_timeout=dict(type='int', default=300, required=False),
+ purge_tags=dict(default=True, type='bool')
+ )
+
+ required_if = [('assign_instances_ipv6', True, ['ipv6_cidr'])]
+
+ module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True, required_if=required_if)
+
+ if module.params.get('assign_instances_ipv6') and not module.params.get('ipv6_cidr'):
+ module.fail_json(msg="assign_instances_ipv6 is True but ipv6_cidr is None or an empty string")
+
+ if not module.botocore_at_least("1.7.0"):
+ module.warn("botocore >= 1.7.0 is required to use wait_timeout for custom wait times")
+
+ connection = module.client('ec2')
+
+ state = module.params.get('state')
+
+ try:
+ if state == 'present':
+ result = ensure_subnet_present(connection, module)
+ elif state == 'absent':
+ result = ensure_subnet_absent(connection, module)
+ except botocore.exceptions.ClientError as e:
+ module.fail_json_aws(e)
+
+ module.exit_json(**result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/flatpak_remote.py b/test/support/integration/plugins/modules/flatpak_remote.py
new file mode 100644
index 00000000..db208f1b
--- /dev/null
+++ b/test/support/integration/plugins/modules/flatpak_remote.py
@@ -0,0 +1,243 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017 John Kwiatkoski (@JayKayy) <jkwiat40@gmail.com>
+# Copyright: (c) 2018 Alexander Bethke (@oolongbrothers) <oolongbrothers@gmx.net>
+# Copyright: (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+
+# ATTENTION CONTRIBUTORS!
+#
+# TL;DR: Run this module's integration tests manually before opening a pull request
+#
+# Long explanation:
+# The integration tests for this module are currently NOT run on the Ansible project's continuous
+# delivery pipeline. So please: When you make changes to this module, make sure that you run the
+# included integration tests manually for both Python 2 and Python 3:
+#
+# Python 2:
+# ansible-test integration -v --docker fedora28 --docker-privileged --allow-unsupported --python 2.7 flatpak_remote
+# Python 3:
+# ansible-test integration -v --docker fedora28 --docker-privileged --allow-unsupported --python 3.6 flatpak_remote
+#
+# Because of external dependencies, the current integration tests are somewhat too slow and brittle
+# to be included right now. I have plans to rewrite the integration tests based on a local flatpak
+# repository so that they can be included into the normal CI pipeline.
+# //oolongbrothers
+
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: flatpak_remote
+version_added: '2.6'
+short_description: Manage flatpak repository remotes
+description:
+- Allows users to add or remove flatpak remotes.
+- The flatpak remotes concept is comparable to what is called repositories in other packaging
+ formats.
+- Currently, remote addition is only supported via I(flatpakrepo) file URLs.
+- Existing remotes will not be updated.
+- See the M(flatpak) module for managing flatpaks.
+author:
+- John Kwiatkoski (@JayKayy)
+- Alexander Bethke (@oolongbrothers)
+requirements:
+- flatpak
+options:
+ executable:
+ description:
+ - The path to the C(flatpak) executable to use.
+ - By default, this module looks for the C(flatpak) executable on the path.
+ default: flatpak
+ flatpakrepo_url:
+ description:
+ - The URL to the I(flatpakrepo) file representing the repository remote to add.
+ - When used with I(state=present), the flatpak remote specified under the I(flatpakrepo_url)
+ is added using the specified installation C(method).
+ - When used with I(state=absent), this is not required.
+ - Required when I(state=present).
+ method:
+ description:
+ - The installation method to use.
+ - Defines if the I(flatpak) is supposed to be installed globally for the whole C(system)
+ or only for the current C(user).
+ choices: [ system, user ]
+ default: system
+ name:
+ description:
+ - The desired name for the flatpak remote to be registered under on the managed host.
+ - When used with I(state=present), the remote will be added to the managed host under
+ the specified I(name).
+ - When used with I(state=absent) the remote with that name will be removed.
+ required: true
+ state:
+ description:
+ - Indicates the desired package state.
+ choices: [ absent, present ]
+ default: present
+'''
+
+EXAMPLES = r'''
+- name: Add the Gnome flatpak remote to the system installation
+ flatpak_remote:
+ name: gnome
+ state: present
+ flatpakrepo_url: https://sdk.gnome.org/gnome-apps.flatpakrepo
+
+- name: Add the flathub flatpak repository remote to the user installation
+ flatpak_remote:
+ name: flathub
+ state: present
+ flatpakrepo_url: https://dl.flathub.org/repo/flathub.flatpakrepo
+ method: user
+
+- name: Remove the Gnome flatpak remote from the user installation
+ flatpak_remote:
+ name: gnome
+ state: absent
+ method: user
+
+- name: Remove the flathub remote from the system installation
+ flatpak_remote:
+ name: flathub
+ state: absent
+'''
+
+RETURN = r'''
+command:
+ description: The exact flatpak command that was executed
+ returned: When a flatpak command has been executed
+ type: str
+ sample: "/usr/bin/flatpak remote-add --system flatpak-test https://dl.flathub.org/repo/flathub.flatpakrepo"
+msg:
+ description: Module error message
+ returned: failure
+ type: str
+ sample: "Executable '/usr/local/bin/flatpak' was not found on the system."
+rc:
+ description: Return code from flatpak binary
+ returned: When a flatpak command has been executed
+ type: int
+ sample: 0
+stderr:
+ description: Error output from flatpak binary
+ returned: When a flatpak command has been executed
+ type: str
+ sample: "error: GPG verification enabled, but no summary found (check that the configured URL in remote config is correct)\n"
+stdout:
+ description: Output from flatpak binary
+ returned: When a flatpak command has been executed
+ type: str
+ sample: "flathub\tFlathub\thttps://dl.flathub.org/repo/\t1\t\n"
+'''
+
+import subprocess
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils._text import to_bytes, to_native
+
+
+def add_remote(module, binary, name, flatpakrepo_url, method):
+ """Add a new remote."""
+ global result
+ command = "{0} remote-add --{1} {2} {3}".format(
+ binary, method, name, flatpakrepo_url)
+ _flatpak_command(module, module.check_mode, command)
+ result['changed'] = True
+
+
+def remove_remote(module, binary, name, method):
+ """Remove an existing remote."""
+ global result
+ command = "{0} remote-delete --{1} --force {2} ".format(
+ binary, method, name)
+ _flatpak_command(module, module.check_mode, command)
+ result['changed'] = True
+
+
+def remote_exists(module, binary, name, method):
+ """Check if the remote exists."""
+ command = "{0} remote-list -d --{1}".format(binary, method)
+ # The query operation for the remote needs to be run even in check mode
+ output = _flatpak_command(module, False, command)
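+ # `flatpak remote-list -d` prints one remote per line with the remote name
+ # in the first column, so comparing the first whitespace-separated field
+ # is enough to detect an existing remote.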
+ for line in output.splitlines():
+ listed_remote = line.split()
+ if len(listed_remote) == 0:
+ continue
+ if listed_remote[0] == to_native(name):
+ return True
+ return False
+
+
+def _flatpak_command(module, noop, command):
+ global result
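+ # In check mode (noop=True), record the command and a zero return code
+ # without actually executing anything.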
+ if noop:
+ result['rc'] = 0
+ result['command'] = command
+ return ""
+
+ process = subprocess.Popen(
+ command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout_data, stderr_data = process.communicate()
+ result['rc'] = process.returncode
+ result['command'] = command
+ result['stdout'] = stdout_data
+ result['stderr'] = stderr_data
+ if result['rc'] != 0:
+ module.fail_json(msg="Failed to execute flatpak command", **result)
+ return to_native(stdout_data)
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ name=dict(type='str', required=True),
+ flatpakrepo_url=dict(type='str'),
+ method=dict(type='str', default='system',
+ choices=['user', 'system']),
+ state=dict(type='str', default="present",
+ choices=['absent', 'present']),
+ executable=dict(type='str', default="flatpak")
+ ),
+ # This module supports check mode
+ supports_check_mode=True,
+ )
+
+ name = module.params['name']
+ flatpakrepo_url = module.params['flatpakrepo_url']
+ method = module.params['method']
+ state = module.params['state']
+ executable = module.params['executable']
+ binary = module.get_bin_path(executable, None)
+
+ if flatpakrepo_url is None:
+ flatpakrepo_url = ''
+
+ global result
+ result = dict(
+ changed=False
+ )
+
+ # If the binary was not found, fail the operation
+ if not binary:
+ module.fail_json(msg="Executable '%s' was not found on the system." % executable, **result)
+
+ remote_already_exists = remote_exists(module, binary, to_bytes(name), method)
+
+ if state == 'present' and not remote_already_exists:
+ add_remote(module, binary, name, flatpakrepo_url, method)
+ elif state == 'absent' and remote_already_exists:
+ remove_remote(module, binary, name, method)
+
+ module.exit_json(**result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/htpasswd.py b/test/support/integration/plugins/modules/htpasswd.py
new file mode 100644
index 00000000..ad12b0c0
--- /dev/null
+++ b/test/support/integration/plugins/modules/htpasswd.py
@@ -0,0 +1,275 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2013, Nimbis Services, Inc.
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = """
+module: htpasswd
+version_added: "1.3"
+short_description: manage user files for basic authentication
+description:
+ - Add and remove username/password entries in a password file using htpasswd.
+ - This is used by web servers such as Apache and Nginx for basic authentication.
+options:
+ path:
+ required: true
+ aliases: [ dest, destfile ]
+ description:
+ - Path to the file that contains the usernames and passwords
+ name:
+ required: true
+ aliases: [ username ]
+ description:
+ - User name to add or remove
+ password:
+ required: false
+ description:
+ - Password associated with user.
+ - Must be specified if user does not exist yet.
+ crypt_scheme:
+ required: false
+ choices: ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"]
+ default: "apr_md5_crypt"
+ description:
+ - Encryption scheme to be used. As well as the four choices listed
+ here, you can also use any other hash supported by passlib, such as
+ md5_crypt and sha256_crypt, which are Linux passwd hashes. If you
+ do so, the password file will not be compatible with Apache or Nginx.
+ state:
+ required: false
+ choices: [ present, absent ]
+ default: "present"
+ description:
+ - Whether the user entry should be present or not
+ create:
+ required: false
+ type: bool
+ default: "yes"
+ description:
+ - Used with C(state=present). If set to C(yes), the file will be created
+ if it does not already exist. If set to C(no), the module will fail if
+ the file does not exist.
+notes:
+ - "This module depends on the I(passlib) Python library, which needs to be installed on all target systems."
+ - "On Debian, Ubuntu, or Fedora: install I(python-passlib)."
+ - "On RHEL or CentOS: Enable EPEL, then install I(python-passlib)."
+requirements: [ passlib>=1.6 ]
+author: "Ansible Core Team"
+extends_documentation_fragment: files
+"""
+
+EXAMPLES = """
+# Add a user to a password file and ensure permissions are set
+- htpasswd:
+ path: /etc/nginx/passwdfile
+ name: janedoe
+ password: '9s36?;fyNp'
+ owner: root
+ group: www-data
+ mode: 0640
+
+# Remove a user from a password file
+- htpasswd:
+ path: /etc/apache2/passwdfile
+ name: foobar
+ state: absent
+
+# Add a user to a password file suitable for use by libpam-pwdfile
+- htpasswd:
+ path: /etc/mail/passwords
+ name: alex
+ password: oedu2eGh
+ crypt_scheme: md5_crypt
+"""
+
+
+import os
+import tempfile
+import traceback
+from distutils.version import LooseVersion
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils._text import to_native
+
+PASSLIB_IMP_ERR = None
+try:
+ from passlib.apache import HtpasswdFile, htpasswd_context
+ from passlib.context import CryptContext
+ import passlib
+except ImportError:
+ PASSLIB_IMP_ERR = traceback.format_exc()
+ passlib_installed = False
+else:
+ passlib_installed = True
+
+apache_hashes = ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"]
+
+
+def create_missing_directories(dest):
+ destpath = os.path.dirname(dest)
+ if not os.path.exists(destpath):
+ os.makedirs(destpath)
+
+
+def present(dest, username, password, crypt_scheme, create, check_mode):
+ """ Ensures user is present
+
+ Returns (msg, changed) """
+ if crypt_scheme in apache_hashes:
+ context = htpasswd_context
+ else:
+ context = CryptContext(schemes=[crypt_scheme] + apache_hashes)
+ if not os.path.exists(dest):
+ if not create:
+ raise ValueError('Destination %s does not exist' % dest)
+ if check_mode:
+ return ("Create %s" % dest, True)
+ create_missing_directories(dest)
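+ # passlib 1.6 replaced HtpasswdFile's autoload/default keyword arguments
+ # with new/default_scheme, so branch on the installed version.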
+ if LooseVersion(passlib.__version__) >= LooseVersion('1.6'):
+ ht = HtpasswdFile(dest, new=True, default_scheme=crypt_scheme, context=context)
+ else:
+ ht = HtpasswdFile(dest, autoload=False, default=crypt_scheme, context=context)
+ if getattr(ht, 'set_password', None):
+ ht.set_password(username, password)
+ else:
+ ht.update(username, password)
+ ht.save()
+ return ("Created %s and added %s" % (dest, username), True)
+ else:
+ if LooseVersion(passlib.__version__) >= LooseVersion('1.6'):
+ ht = HtpasswdFile(dest, new=False, default_scheme=crypt_scheme, context=context)
+ else:
+ ht = HtpasswdFile(dest, default=crypt_scheme, context=context)
+
+ found = None
+ if getattr(ht, 'check_password', None):
+ found = ht.check_password(username, password)
+ else:
+ found = ht.verify(username, password)
+
+ if found:
+ return ("%s already present" % username, False)
+ else:
+ if not check_mode:
+ if getattr(ht, 'set_password', None):
+ ht.set_password(username, password)
+ else:
+ ht.update(username, password)
+ ht.save()
+ return ("Add/update %s" % username, True)
+
+
+def absent(dest, username, check_mode):
+ """ Ensures user is absent
+
+ Returns (msg, changed) """
+ if LooseVersion(passlib.__version__) >= LooseVersion('1.6'):
+ ht = HtpasswdFile(dest, new=False)
+ else:
+ ht = HtpasswdFile(dest)
+
+ if username not in ht.users():
+ return ("%s not present" % username, False)
+ else:
+ if not check_mode:
+ ht.delete(username)
+ ht.save()
+ return ("Remove %s" % username, True)
+
+
+def check_file_attrs(module, changed, message):
+
+ file_args = module.load_file_common_arguments(module.params)
+ if module.set_fs_attributes_if_different(file_args, False):
+
+ if changed:
+ message += " and "
+ changed = True
+ message += "ownership, perms or SE linux context changed"
+
+ return message, changed
+
+
+def main():
+ arg_spec = dict(
+ path=dict(required=True, aliases=["dest", "destfile"]),
+ name=dict(required=True, aliases=["username"]),
+ password=dict(required=False, default=None, no_log=True),
+ crypt_scheme=dict(required=False, default="apr_md5_crypt"),
+ state=dict(required=False, default="present"),
+ create=dict(type='bool', default='yes'),
+
+ )
+ module = AnsibleModule(argument_spec=arg_spec,
+ add_file_common_args=True,
+ supports_check_mode=True)
+
+ path = module.params['path']
+ username = module.params['name']
+ password = module.params['password']
+ crypt_scheme = module.params['crypt_scheme']
+ state = module.params['state']
+ create = module.params['create']
+ check_mode = module.check_mode
+
+ if not passlib_installed:
+ module.fail_json(msg=missing_required_lib("passlib"), exception=PASSLIB_IMP_ERR)
+
+ # Check the file for blank lines in an effort to avoid a "need more than 1 value to unpack" error.
+ try:
+ f = open(path, "r")
+ except IOError:
+ # No preexisting file to remove blank lines from
+ f = None
+ else:
+ try:
+ lines = f.readlines()
+ finally:
+ f.close()
+
+ # Rewriting the file would be reported as a change, so only rewrite it when it actually contains blank lines.
+ strip = False
+ for line in lines:
+ if not line.strip():
+ strip = True
+ break
+
+ if strip:
+ # If check mode, create a temporary file
+ if check_mode:
+ temp = tempfile.NamedTemporaryFile()
+ path = temp.name
+ f = open(path, "w")
+ try:
+ for line in lines:
+ if line.strip():
+ f.write(line)
+ finally:
+ f.close()
+
+ try:
+ if state == 'present':
+ (msg, changed) = present(path, username, password, crypt_scheme, create, check_mode)
+ elif state == 'absent':
+ if not os.path.exists(path):
+ module.exit_json(msg="%s not present" % username,
+ warnings="%s does not exist" % path, changed=False)
+ (msg, changed) = absent(path, username, check_mode)
+ else:
+ module.fail_json(msg="Invalid state: %s" % state)
+
+ (msg, changed) = check_file_attrs(module, changed, msg)
+ module.exit_json(msg=msg, changed=changed)
+ except Exception as e:
+ module.fail_json(msg=to_native(e))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/locale_gen.py b/test/support/integration/plugins/modules/locale_gen.py
new file mode 100644
index 00000000..4968b834
--- /dev/null
+++ b/test/support/integration/plugins/modules/locale_gen.py
@@ -0,0 +1,237 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = '''
+---
+module: locale_gen
+short_description: Creates or removes locales
+description:
+ - Manages locales by editing /etc/locale.gen and invoking locale-gen.
+version_added: "1.6"
+author:
+- Augustus Kling (@AugustusKling)
+options:
+ name:
+ description:
+ - Name and encoding of the locale, such as "en_GB.UTF-8".
+ required: true
+ state:
+ description:
+ - Whether the locale shall be present.
+ choices: [ absent, present ]
+ default: present
+'''
+
+EXAMPLES = '''
+- name: Ensure a locale exists
+ locale_gen:
+ name: de_CH.UTF-8
+ state: present
+'''
+
+import os
+import re
+from subprocess import Popen, PIPE, call
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils._text import to_native
+
+LOCALE_NORMALIZATION = {
+ ".utf8": ".UTF-8",
+ ".eucjp": ".EUC-JP",
+ ".iso885915": ".ISO-8859-15",
+ ".cp1251": ".CP1251",
+ ".koi8r": ".KOI8-R",
+ ".armscii8": ".ARMSCII-8",
+ ".euckr": ".EUC-KR",
+ ".gbk": ".GBK",
+ ".gb18030": ".GB18030",
+ ".euctw": ".EUC-TW",
+}
+
+
+# ===========================================
+# locale_gen module specific support methods.
+#
+
+def is_available(name, ubuntuMode):
+ """Check if the given locale is available on the system. This is done by
+ checking either:
+ * if the locale is present in /etc/locale.gen
+ * or if the locale is present in /usr/share/i18n/SUPPORTED"""
+ if ubuntuMode:
+ __regexp = r'^(?P<locale>\S+_\S+) (?P<charset>\S+)\s*$'
+ __locales_available = '/usr/share/i18n/SUPPORTED'
+ else:
+ __regexp = r'^#{0,1}\s*(?P<locale>\S+_\S+) (?P<charset>\S+)\s*$'
+ __locales_available = '/etc/locale.gen'
+
+ re_compiled = re.compile(__regexp)
+ fd = open(__locales_available, 'r')
+ try:
+ for line in fd:
+ result = re_compiled.match(line)
+ if result and result.group('locale') == name:
+ return True
+ finally:
+ # Close the file even when the locale is found and we return early.
+ fd.close()
+ return False
+
+
+def is_present(name):
+ """Checks if the given locale is currently installed."""
+ output = Popen(["locale", "-a"], stdout=PIPE).communicate()[0]
+ output = to_native(output)
+ return any(fix_case(name) == fix_case(line) for line in output.splitlines())
+
+
+def fix_case(name):
+ """locale -a might return the encoding in either lower or upper case.
+ Passing through this function makes them uniform for comparisons."""
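+ # For example, fix_case("de_CH.utf8") and fix_case("de_CH.UTF-8") both
+ # return "de_CH.UTF-8", so the two spellings compare equal.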
+ for s, r in LOCALE_NORMALIZATION.items():
+ name = name.replace(s, r)
+ return name
+
+
+def replace_line(existing_line, new_line):
+ """Replaces lines in /etc/locale.gen"""
+ try:
+ f = open("/etc/locale.gen", "r")
+ lines = [line.replace(existing_line, new_line) for line in f]
+ finally:
+ f.close()
+ try:
+ f = open("/etc/locale.gen", "w")
+ f.write("".join(lines))
+ finally:
+ f.close()
+
+
+def set_locale(name, enabled=True):
+ """ Sets the state of the locale. Defaults to enabled. """
+ search_string = r'#{0,1}\s*%s (?P<charset>.+)' % name
+ if enabled:
+ new_string = r'%s \g<charset>' % (name)
+ else:
+ new_string = r'# %s \g<charset>' % (name)
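+ # The pattern matches the locale's line in /etc/locale.gen with or without
+ # a leading '#', so the substitution either uncomments the line (enable)
+ # or comments it out (disable).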
+ try:
+ f = open("/etc/locale.gen", "r")
+ lines = [re.sub(search_string, new_string, line) for line in f]
+ finally:
+ f.close()
+ try:
+ f = open("/etc/locale.gen", "w")
+ f.write("".join(lines))
+ finally:
+ f.close()
+
+
+def apply_change(targetState, name):
+ """Create or remove locale.
+
+ Keyword arguments:
+ targetState -- Desired state, either present or absent.
+ name -- Name including encoding such as de_CH.UTF-8.
+ """
+ if targetState == "present":
+ # Create locale.
+ set_locale(name, enabled=True)
+ else:
+ # Delete locale.
+ set_locale(name, enabled=False)
+
+ localeGenExitValue = call("locale-gen")
+ if localeGenExitValue != 0:
+ raise EnvironmentError(localeGenExitValue, "locale-gen failed to execute; it returned " + str(localeGenExitValue))
+
+
+def apply_change_ubuntu(targetState, name):
+ """Create or remove locale.
+
+ Keyword arguments:
+ targetState -- Desired state, either present or absent.
+ name -- Name including encoding such as de_CH.UTF-8.
+ """
+ if targetState == "present":
+ # Create locale.
+ # Ubuntu's patched locale-gen automatically adds the new locale to /var/lib/locales/supported.d/local
+ localeGenExitValue = call(["locale-gen", name])
+ else:
+ # Delete locale involves discarding the locale from /var/lib/locales/supported.d/local and regenerating all locales.
+ try:
+ f = open("/var/lib/locales/supported.d/local", "r")
+ content = f.readlines()
+ finally:
+ f.close()
+ try:
+ f = open("/var/lib/locales/supported.d/local", "w")
+ for line in content:
+ locale, charset = line.split(' ')
+ if locale != name:
+ f.write(line)
+ finally:
+ f.close()
+ # Purge locales and regenerate.
+ # Please provide a patch if you know how to avoid regenerating the locales to keep!
+ localeGenExitValue = call(["locale-gen", "--purge"])
+
+ if localeGenExitValue != 0:
+ raise EnvironmentError(localeGenExitValue, "locale-gen failed to execute; it returned " + str(localeGenExitValue))
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ name=dict(type='str', required=True),
+ state=dict(type='str', default='present', choices=['absent', 'present']),
+ ),
+ supports_check_mode=True,
+ )
+
+ name = module.params['name']
+ state = module.params['state']
+
+ if not os.path.exists("/etc/locale.gen"):
+ if os.path.exists("/var/lib/locales/supported.d/"):
+ # Ubuntu created its own system to manage locales.
+ ubuntuMode = True
+ else:
+ module.fail_json(msg="/etc/locale.gen and /var/lib/locales/supported.d/local are missing. Is the package \"locales\" installed?")
+ else:
+ # We found the common way to manage locales.
+ ubuntuMode = False
+
+ if not is_available(name, ubuntuMode):
+ module.fail_json(msg="The locale you've entered is not available "
+ "on your system.")
+
+ if is_present(name):
+ prev_state = "present"
+ else:
+ prev_state = "absent"
+ changed = (prev_state != state)
+
+ if module.check_mode:
+ module.exit_json(changed=changed)
+ else:
+ if changed:
+ try:
+ if ubuntuMode is False:
+ apply_change(state, name)
+ else:
+ apply_change_ubuntu(state, name)
+ except EnvironmentError as e:
+ module.fail_json(msg=to_native(e), exitValue=e.errno)
+
+ module.exit_json(name=name, changed=changed, msg="OK")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/lvg.py b/test/support/integration/plugins/modules/lvg.py
new file mode 100644
index 00000000..e2035f68
--- /dev/null
+++ b/test/support/integration/plugins/modules/lvg.py
@@ -0,0 +1,295 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2013, Alexander Bulimov <lazywolf0@gmail.com>
+# Based on lvol module by Jeroen Hoekx <jeroen.hoekx@dsquare.be>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+author:
+- Alexander Bulimov (@abulimov)
+module: lvg
+short_description: Configure LVM volume groups
+description:
+ - This module creates, removes or resizes volume groups.
+version_added: "1.1"
+options:
+ vg:
+ description:
+ - The name of the volume group.
+ type: str
+ required: true
+ pvs:
+ description:
+ - List of comma-separated devices to use as physical devices in this volume group.
+ - Required when creating or resizing volume group.
+ - The module will take care of running pvcreate if needed.
+ type: list
+ pesize:
+ description:
+ - "The size of the physical extent. I(pesize) must be a power of 2 of at least 1 sector
+ (where the sector size is the largest sector size of the PVs currently used in the VG),
+ or at least 128KiB."
+ - Since Ansible 2.6, pesize can optionally be suffixed by a UNIT (k/K/m/M/g/G); the default unit is megabytes.
+ type: str
+ default: "4"
+ pv_options:
+ description:
+ - Additional options to pass to C(pvcreate) when creating the volume group.
+ type: str
+ version_added: "2.4"
+ vg_options:
+ description:
+ - Additional options to pass to C(vgcreate) when creating the volume group.
+ type: str
+ version_added: "1.6"
+ state:
+ description:
+ - Control if the volume group exists.
+ type: str
+ choices: [ absent, present ]
+ default: present
+ force:
+ description:
+ - If C(yes), allows removing a volume group that still contains logical volumes.
+ type: bool
+ default: no
+seealso:
+- module: filesystem
+- module: lvol
+- module: parted
+notes:
+ - This module does not modify PE size for already present volume group.
+'''
+
+EXAMPLES = r'''
+- name: Create a volume group on top of /dev/sda1 with physical extent size = 32MB
+ lvg:
+ vg: vg.services
+ pvs: /dev/sda1
+ pesize: 32
+
+- name: Create a volume group on top of /dev/sdb with physical extent size = 128KiB
+ lvg:
+ vg: vg.services
+ pvs: /dev/sdb
+ pesize: 128K
+
+# If, for example, we already have VG vg.services on top of /dev/sdb1,
+# this VG will be extended by /dev/sdc5. Or if vg.services was created on
+# top of /dev/sda5, we first extend it with /dev/sdb1 and /dev/sdc5,
+# and then reduce by /dev/sda5.
+- name: Create or resize a volume group on top of /dev/sdb1 and /dev/sdc5.
+ lvg:
+ vg: vg.services
+ pvs: /dev/sdb1,/dev/sdc5
+
+- name: Remove a volume group with name vg.services
+ lvg:
+ vg: vg.services
+ state: absent
+'''
+
+import itertools
+import os
+
+from ansible.module_utils.basic import AnsibleModule
+
+
+def parse_vgs(data):
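+ # Each line of `vgs --noheadings --separator ';'` output looks like
+ # "name;pv_count;lv_count".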
+ vgs = []
+ for line in data.splitlines():
+ parts = line.strip().split(';')
+ vgs.append({
+ 'name': parts[0],
+ 'pv_count': int(parts[1]),
+ 'lv_count': int(parts[2]),
+ })
+ return vgs
+
+
+def find_mapper_device_name(module, dm_device):
+ dmsetup_cmd = module.get_bin_path('dmsetup', True)
+ mapper_prefix = '/dev/mapper/'
+ rc, dm_name, err = module.run_command("%s info -C --noheadings -o name %s" % (dmsetup_cmd, dm_device))
+ if rc != 0:
+ module.fail_json(msg="Failed executing dmsetup command.", rc=rc, err=err)
+ mapper_device = mapper_prefix + dm_name.rstrip()
+ return mapper_device
+
+
+def parse_pvs(module, data):
+ pvs = []
+ dm_prefix = '/dev/dm-'
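+ # pvs may report devices as /dev/dm-N; translate those to their stable
+ # /dev/mapper names so they can be compared with user-supplied paths.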
+ for line in data.splitlines():
+ parts = line.strip().split(';')
+ if parts[0].startswith(dm_prefix):
+ parts[0] = find_mapper_device_name(module, parts[0])
+ pvs.append({
+ 'name': parts[0],
+ 'vg_name': parts[1],
+ })
+ return pvs
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ vg=dict(type='str', required=True),
+ pvs=dict(type='list'),
+ pesize=dict(type='str', default='4'),
+ pv_options=dict(type='str', default=''),
+ vg_options=dict(type='str', default=''),
+ state=dict(type='str', default='present', choices=['absent', 'present']),
+ force=dict(type='bool', default=False),
+ ),
+ supports_check_mode=True,
+ )
+
+ vg = module.params['vg']
+ state = module.params['state']
+ force = module.boolean(module.params['force'])
+ pesize = module.params['pesize']
+ pvoptions = module.params['pv_options'].split()
+ vgoptions = module.params['vg_options'].split()
+
+ dev_list = []
+ if module.params['pvs']:
+ dev_list = list(module.params['pvs'])
+ elif state == 'present':
+ module.fail_json(msg="No physical volumes given.")
+
+ # LVM always uses real paths not symlinks so replace symlinks with actual path
+ for idx, dev in enumerate(dev_list):
+ dev_list[idx] = os.path.realpath(dev)
+
+ if state == 'present':
+ # check given devices
+ for test_dev in dev_list:
+ if not os.path.exists(test_dev):
+ module.fail_json(msg="Device %s not found." % test_dev)
+
+ # get pv list
+ pvs_cmd = module.get_bin_path('pvs', True)
+ if dev_list:
+ pvs_filter_pv_name = ' || '.join(
+ 'pv_name = {0}'.format(x)
+ for x in itertools.chain(dev_list, module.params['pvs'])
+ )
+ pvs_filter_vg_name = 'vg_name = {0}'.format(vg)
+ pvs_filter = "--select '{0} || {1}' ".format(pvs_filter_pv_name, pvs_filter_vg_name)
+ else:
+ pvs_filter = ''
+ rc, current_pvs, err = module.run_command("%s --noheadings -o pv_name,vg_name --separator ';' %s" % (pvs_cmd, pvs_filter))
+ if rc != 0:
+ module.fail_json(msg="Failed executing pvs command.", rc=rc, err=err)
+
+ # check pv for devices
+ pvs = parse_pvs(module, current_pvs)
+ used_pvs = [pv for pv in pvs if pv['name'] in dev_list and pv['vg_name'] and pv['vg_name'] != vg]
+ if used_pvs:
+ module.fail_json(msg="Device %s is already in %s volume group." % (used_pvs[0]['name'], used_pvs[0]['vg_name']))
+
+ vgs_cmd = module.get_bin_path('vgs', True)
+ rc, current_vgs, err = module.run_command("%s --noheadings -o vg_name,pv_count,lv_count --separator ';'" % vgs_cmd)
+
+ if rc != 0:
+ module.fail_json(msg="Failed executing vgs command.", rc=rc, err=err)
+
+ changed = False
+
+ vgs = parse_vgs(current_vgs)
+
+ for test_vg in vgs:
+ if test_vg['name'] == vg:
+ this_vg = test_vg
+ break
+ else:
+ this_vg = None
+
+ if this_vg is None:
+ if state == 'present':
+ # create VG
+ if module.check_mode:
+ changed = True
+ else:
+ # create PV
+ pvcreate_cmd = module.get_bin_path('pvcreate', True)
+ for current_dev in dev_list:
+ rc, _, err = module.run_command([pvcreate_cmd] + pvoptions + ['-f', str(current_dev)])
+ if rc == 0:
+ changed = True
+ else:
+ module.fail_json(msg="Creating physical volume '%s' failed" % current_dev, rc=rc, err=err)
+ vgcreate_cmd = module.get_bin_path('vgcreate', True)
+ rc, _, err = module.run_command([vgcreate_cmd] + vgoptions + ['-s', pesize, vg] + dev_list)
+ if rc == 0:
+ changed = True
+ else:
+ module.fail_json(msg="Creating volume group '%s' failed" % vg, rc=rc, err=err)
+ else:
+ if state == 'absent':
+ if module.check_mode:
+ module.exit_json(changed=True)
+ else:
+ if this_vg['lv_count'] == 0 or force:
+ # remove VG
+ vgremove_cmd = module.get_bin_path('vgremove', True)
+ rc, _, err = module.run_command("%s --force %s" % (vgremove_cmd, vg))
+ if rc == 0:
+ module.exit_json(changed=True)
+ else:
+ module.fail_json(msg="Failed to remove volume group %s" % (vg), rc=rc, err=err)
+ else:
+ module.fail_json(msg="Refuse to remove non-empty volume group %s without force=yes" % (vg))
+
+ # resize VG
+ current_devs = [os.path.realpath(pv['name']) for pv in pvs if pv['vg_name'] == vg]
+ devs_to_remove = list(set(current_devs) - set(dev_list))
+ devs_to_add = list(set(dev_list) - set(current_devs))
+
+ if devs_to_add or devs_to_remove:
+ if module.check_mode:
+ changed = True
+ else:
+ if devs_to_add:
+ devs_to_add_string = ' '.join(devs_to_add)
+ # create PV
+ pvcreate_cmd = module.get_bin_path('pvcreate', True)
+ for current_dev in devs_to_add:
+ rc, _, err = module.run_command([pvcreate_cmd] + pvoptions + ['-f', str(current_dev)])
+ if rc == 0:
+ changed = True
+ else:
+ module.fail_json(msg="Creating physical volume '%s' failed" % current_dev, rc=rc, err=err)
+ # add PV to our VG
+ vgextend_cmd = module.get_bin_path('vgextend', True)
+ rc, _, err = module.run_command("%s %s %s" % (vgextend_cmd, vg, devs_to_add_string))
+ if rc == 0:
+ changed = True
+ else:
+ module.fail_json(msg="Unable to extend %s by %s." % (vg, devs_to_add_string), rc=rc, err=err)
+
+ # remove some PV from our VG
+ if devs_to_remove:
+ devs_to_remove_string = ' '.join(devs_to_remove)
+ vgreduce_cmd = module.get_bin_path('vgreduce', True)
+ rc, _, err = module.run_command("%s --force %s %s" % (vgreduce_cmd, vg, devs_to_remove_string))
+ if rc == 0:
+ changed = True
+ else:
+ module.fail_json(msg="Unable to reduce %s by %s." % (vg, devs_to_remove_string), rc=rc, err=err)
+
+ module.exit_json(changed=changed)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/mongodb_parameter.py b/test/support/integration/plugins/modules/mongodb_parameter.py
new file mode 100644
index 00000000..05de42b2
--- /dev/null
+++ b/test/support/integration/plugins/modules/mongodb_parameter.py
@@ -0,0 +1,223 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2016, Loic Blot <loic.blot@unix-experience.fr>
+# Sponsored by Infopro Digital. http://www.infopro-digital.com/
+# Sponsored by E.T.A.I. http://www.etai.fr/
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = r'''
+---
+module: mongodb_parameter
+short_description: Change an administrative parameter on a MongoDB server
+description:
+ - Change an administrative parameter on a MongoDB server.
+version_added: "2.1"
+options:
+ login_user:
+ description:
+ - The MongoDB username used to authenticate with.
+ type: str
+ login_password:
+ description:
+ - The login user's password used to authenticate with.
+ type: str
+ login_host:
+ description:
+ - The host running the database.
+ type: str
+ default: localhost
+ login_port:
+ description:
+ - The MongoDB port to connect to.
+ default: 27017
+ type: int
+ login_database:
+ description:
+ - The database where login credentials are stored.
+ type: str
+ replica_set:
+ description:
+ - Replica set to connect to (automatically connects to primary for writes).
+ type: str
+ ssl:
+ description:
+ - Whether to use an SSL connection when connecting to the database.
+ type: bool
+ default: no
+ param:
+ description:
+ - MongoDB administrative parameter to modify.
+ type: str
+ required: true
+ value:
+ description:
+ - MongoDB administrative parameter value to set.
+ type: str
+ required: true
+ param_type:
+ description:
+ - Define the type of parameter value.
+ default: str
+ type: str
+ choices: [int, str]
+
+notes:
+ - Requires the pymongo Python package on the remote host, version 2.4.2+.
+ - This can be installed using pip or the OS package manager.
+ - See also U(http://api.mongodb.org/python/current/installation.html)
+requirements: [ "pymongo" ]
+author: "Loic Blot (@nerzhul)"
+'''
+
+EXAMPLES = r'''
+- name: Set MongoDB syncdelay to 60 (this is an int)
+ mongodb_parameter:
+ param: syncdelay
+ value: 60
+ param_type: int
+'''
+
+RETURN = r'''
+before:
+ description: value before modification
+ returned: success
+ type: str
+after:
+ description: value after modification
+ returned: success
+ type: str
+'''
+
+import os
+import traceback
+
+try:
+ from pymongo.errors import ConnectionFailure
+ from pymongo.errors import OperationFailure
+ from pymongo import version as PyMongoVersion
+ from pymongo import MongoClient
+except ImportError:
+ try: # for older PyMongo 2.2
+ from pymongo import Connection as MongoClient
+ except ImportError:
+ pymongo_found = False
+ else:
+ pymongo_found = True
+else:
+ pymongo_found = True
+
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils.six.moves import configparser
+from ansible.module_utils._text import to_native
+
+
+# =========================================
+# MongoDB module specific support methods.
+#
+
+def load_mongocnf():
+ config = configparser.RawConfigParser()
+ mongocnf = os.path.expanduser('~/.mongodb.cnf')
+
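+ # ~/.mongodb.cnf is expected to contain a [client] section, for example:
+ # [client]
+ # user = admin
+ # pass = secret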
+ try:
+ config.readfp(open(mongocnf))
+ creds = dict(
+ user=config.get('client', 'user'),
+ password=config.get('client', 'pass')
+ )
+ except (configparser.NoOptionError, IOError):
+ return False
+
+ return creds
+
+
+# =========================================
+# Module execution.
+#
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ login_user=dict(default=None),
+ login_password=dict(default=None, no_log=True),
+ login_host=dict(default='localhost'),
+ login_port=dict(default=27017, type='int'),
+ login_database=dict(default=None),
+ replica_set=dict(default=None),
+ param=dict(required=True),
+ value=dict(required=True),
+ param_type=dict(default="str", choices=['str', 'int']),
+ ssl=dict(default=False, type='bool'),
+ )
+ )
+
+ if not pymongo_found:
+ module.fail_json(msg=missing_required_lib('pymongo'))
+
+ login_user = module.params['login_user']
+ login_password = module.params['login_password']
+ login_host = module.params['login_host']
+ login_port = module.params['login_port']
+ login_database = module.params['login_database']
+
+ replica_set = module.params['replica_set']
+ ssl = module.params['ssl']
+
+ param = module.params['param']
+ param_type = module.params['param_type']
+ value = module.params['value']
+
+ # Verify the parameter value is consistent with the specified type
+ try:
+ if param_type == 'int':
+ value = int(value)
+ except ValueError:
+ module.fail_json(msg="value '%s' is not %s" % (value, param_type))
+
+ try:
+ if replica_set:
+ client = MongoClient(login_host, int(login_port), replicaset=replica_set, ssl=ssl)
+ else:
+ client = MongoClient(login_host, int(login_port), ssl=ssl)
+
+ if login_user is None and login_password is None:
+ mongocnf_creds = load_mongocnf()
+ if mongocnf_creds is not False:
+ login_user = mongocnf_creds['user']
+ login_password = mongocnf_creds['password']
+ elif login_password is None or login_user is None:
+ module.fail_json(msg='when supplying login arguments, both login_user and login_password must be provided')
+
+ if login_user is not None and login_password is not None:
+ client.admin.authenticate(login_user, login_password, source=login_database)
+
+ except ConnectionFailure as e:
+ module.fail_json(msg='unable to connect to database: %s' % to_native(e), exception=traceback.format_exc())
+
+ db = client.admin
+
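+ # setParameter must be issued against the admin database; the server reports
+ # the previous value in the 'was' field of the reply, which the change
+ # detection below relies on.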
+ try:
+ after_value = db.command("setParameter", **{param: value})
+ except OperationFailure as e:
+ module.fail_json(msg="unable to change parameter: %s" % to_native(e), exception=traceback.format_exc())
+
+ if "was" not in after_value:
+ module.exit_json(changed=True, msg="Unable to determine old value, assume it changed.")
+ else:
+ module.exit_json(changed=(value != after_value["was"]), before=after_value["was"],
+ after=value)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/mongodb_user.py b/test/support/integration/plugins/modules/mongodb_user.py
new file mode 100644
index 00000000..362b3aa4
--- /dev/null
+++ b/test/support/integration/plugins/modules/mongodb_user.py
@@ -0,0 +1,474 @@
+#!/usr/bin/python
+
+# (c) 2012, Elliott Foster <elliott@fourkitchens.com>
+# Sponsored by Four Kitchens http://fourkitchens.com.
+# (c) 2014, Epic Games, Inc.
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: mongodb_user
+short_description: Adds or removes a user from a MongoDB database
+description:
+ - Adds or removes a user from a MongoDB database.
+version_added: "1.1"
+options:
+ login_user:
+ description:
+ - The MongoDB username used to authenticate with.
+ type: str
+ login_password:
+ description:
+ - The login user's password used to authenticate with.
+ type: str
+ login_host:
+ description:
+ - The host running the database.
+ default: localhost
+ type: str
+ login_port:
+ description:
+ - The MongoDB port to connect to.
+ default: '27017'
+ type: str
+ login_database:
+ version_added: "2.0"
+ description:
+ - The database where login credentials are stored.
+ type: str
+ replica_set:
+ version_added: "1.6"
+ description:
+ - Replica set to connect to (automatically connects to primary for writes).
+ type: str
+ database:
+ description:
+ - The name of the database to add/remove the user from.
+ required: true
+ type: str
+ aliases: [db]
+ name:
+ description:
+ - The name of the user to add or remove.
+ required: true
+ aliases: [user]
+ type: str
+ password:
+ description:
+ - The password to use for the user.
+ type: str
+ aliases: [pass]
+ ssl:
+ version_added: "1.8"
+ description:
+ - Whether to use an SSL connection when connecting to the database.
+ type: bool
+ ssl_cert_reqs:
+ version_added: "2.2"
+ description:
+ - Specifies whether a certificate is required from the other side of the connection,
+ and whether it will be validated if provided.
+ default: CERT_REQUIRED
+ choices: [CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED]
+ type: str
+ roles:
+ version_added: "1.3"
+ type: list
+ elements: raw
+ description:
+ - >
+ The database user roles. Valid values are one or more of the following strings:
+ 'read', 'readWrite', 'dbAdmin', 'userAdmin', 'clusterAdmin', 'readAnyDatabase', 'readWriteAnyDatabase', 'userAdminAnyDatabase',
+ 'dbAdminAnyDatabase'
+ - "Or the following dictionary '{ db: DATABASE_NAME, role: ROLE_NAME }'."
+ - "This param requires pymongo 2.5+. If it is a string, mongodb 2.4+ is also required. If it is a dictionary, mongo 2.6+ is required."
+ state:
+ description:
+ - The database user state.
+ default: present
+ choices: [absent, present]
+ type: str
+ update_password:
+ default: always
+ choices: [always, on_create]
+ version_added: "2.1"
+ description:
+ - C(always) will update passwords if they differ.
+ - C(on_create) will only set the password for newly created users.
+ type: str
+
+notes:
+ - Requires the pymongo Python package on the remote host, version 2.4.2+. This
+ can be installed using pip or the OS package manager. See also U(http://api.mongodb.org/python/current/installation.html).
+requirements: [ "pymongo" ]
+author:
+ - "Elliott Foster (@elliotttf)"
+ - "Julien Thebault (@Lujeni)"
+'''
+
+EXAMPLES = '''
+- name: Create 'burgers' database user with name 'bob' and password '12345'.
+ mongodb_user:
+ database: burgers
+ name: bob
+ password: 12345
+ state: present
+
+- name: Create a database user via SSL (MongoDB must be compiled with the SSL option and configured properly)
+ mongodb_user:
+ database: burgers
+ name: bob
+ password: 12345
+ state: present
+ ssl: True
+
+- name: Delete 'burgers' database user with name 'bob'.
+ mongodb_user:
+ database: burgers
+ name: bob
+ state: absent
+
+- name: Define more users with various specific roles (if not defined, no roles are assigned, and the user will be added via the pre-MongoDB 2.2 style)
+ mongodb_user:
+ database: burgers
+ name: ben
+ password: 12345
+ roles: read
+ state: present
+
+- name: Define roles
+ mongodb_user:
+ database: burgers
+ name: jim
+ password: 12345
+ roles: readWrite,dbAdmin,userAdmin
+ state: present
+
+- name: Define roles
+ mongodb_user:
+ database: burgers
+ name: joe
+ password: 12345
+ roles: readWriteAnyDatabase
+ state: present
+
+- name: Add a user to database in a replica set, the primary server is automatically discovered and written to
+ mongodb_user:
+ database: burgers
+ name: bob
+ replica_set: belcher
+ password: 12345
+ roles: readWriteAnyDatabase
+ state: present
+
+# Add a user 'oplog_reader' with read-only access to the 'local' database on the replica_set 'belcher'. This is useful for oplog access (MONGO_OPLOG_URL).
+# Note that the credentials must be added to the 'admin' database, because the 'local' database is not synchronized and can't receive user credentials.
+# To log in with such a user, the connection string should be MONGO_OPLOG_URL="mongodb://oplog_reader:oplog_reader_password@server1,server2/local?authSource=admin"
+# This syntax requires mongodb 2.6+ and pymongo 2.5+
+- name: Roles as a dictionary
+ mongodb_user:
+ login_user: root
+ login_password: root_password
+ database: admin
+ user: oplog_reader
+ password: oplog_reader_password
+ state: present
+ replica_set: belcher
+ roles:
+ - db: local
+ role: read
+
+'''
+
+RETURN = '''
+user:
+ description: The name of the user to add or remove.
+ returned: success
+ type: str
+'''
+
+import os
+import ssl as ssl_lib
+import traceback
+from distutils.version import LooseVersion
+from operator import itemgetter
+
+try:
+ from pymongo.errors import ConnectionFailure
+ from pymongo.errors import OperationFailure
+ from pymongo import version as PyMongoVersion
+ from pymongo import MongoClient
+except ImportError:
+ try: # for older PyMongo 2.2
+ from pymongo import Connection as MongoClient
+ except ImportError:
+ pymongo_found = False
+ else:
+ pymongo_found = True
+else:
+ pymongo_found = True
+
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils.six import binary_type, text_type
+from ansible.module_utils.six.moves import configparser
+from ansible.module_utils._text import to_native
+
+
+# =========================================
+# MongoDB module specific support methods.
+#
+
+def check_compatibility(module, client):
+ """Check the compatibility between the driver and the database.
+
+ See: https://docs.mongodb.com/ecosystem/drivers/driver-compatibility-reference/#python-driver-compatibility
+
+ Args:
+ module: Ansible module.
+ client (cursor): Mongodb cursor on admin database.
+ """
+ loose_srv_version = LooseVersion(client.server_info()['version'])
+ loose_driver_version = LooseVersion(PyMongoVersion)
+
+ if loose_srv_version >= LooseVersion('3.2') and loose_driver_version < LooseVersion('3.2'):
+ module.fail_json(msg=' (Note: you must use pymongo 3.2+ with MongoDB >= 3.2)')
+
+ elif loose_srv_version >= LooseVersion('3.0') and loose_driver_version <= LooseVersion('2.8'):
+ module.fail_json(msg=' (Note: you must use pymongo 2.8+ with MongoDB 3.0)')
+
+ elif loose_srv_version >= LooseVersion('2.6') and loose_driver_version <= LooseVersion('2.7'):
+ module.fail_json(msg=' (Note: you must use pymongo 2.7+ with MongoDB 2.6)')
+
+ elif LooseVersion(PyMongoVersion) <= LooseVersion('2.5'):
+ module.fail_json(msg=' (Note: you must be on mongodb 2.4+ and pymongo 2.5+ to use the roles param)')
+
+
+def user_find(client, user, db_name):
+ """Check if the user exists.
+
+ Args:
+ client (cursor): Mongodb cursor on admin database.
+ user (str): User to check.
+ db_name (str): User's database.
+
+ Returns:
+ dict: the user document when the user exists, False otherwise.
+ """
+ for mongo_user in client["admin"].system.users.find():
+ if mongo_user['user'] == user:
+ # NOTE: there is no 'db' field in mongo 2.4.
+ if 'db' not in mongo_user:
+ return mongo_user
+
+ if mongo_user["db"] == db_name:
+ return mongo_user
+ return False
+
+
+def user_add(module, client, db_name, user, password, roles):
+ # pymongo's add_user is a _create_or_update_user, so we won't know whether the user was created or updated
+ # without reproducing a lot of the logic in database.py of pymongo
+ db = client[db_name]
+
+ if roles is None:
+ db.add_user(user, password, False)
+ else:
+ db.add_user(user, password, None, roles=roles)
+
+
+def user_remove(module, client, db_name, user):
+ exists = user_find(client, user, db_name)
+ if exists:
+ if module.check_mode:
+ module.exit_json(changed=True, user=user)
+ db = client[db_name]
+ db.remove_user(user)
+ else:
+ module.exit_json(changed=False, user=user)
+
+
+def load_mongocnf():
+ config = configparser.RawConfigParser()
+ mongocnf = os.path.expanduser('~/.mongodb.cnf')
+
+ try:
+ config.readfp(open(mongocnf))
+ creds = dict(
+ user=config.get('client', 'user'),
+ password=config.get('client', 'pass')
+ )
+ except (configparser.NoOptionError, IOError):
+ return False
+
+ return creds
+
+
+def check_if_roles_changed(uinfo, roles, db_name):
+ # We must be aware of users which can read the oplog on a replicaset
+ # Such users must have access to the local DB, but since this DB does not store users credentials
+ # and is not synchronized among replica sets, the user must be stored on the admin db
+ # Therefore their structure is the following :
+ # {
+ # "_id" : "admin.oplog_reader",
+ # "user" : "oplog_reader",
+ # "db" : "admin", # <-- admin DB
+ # "roles" : [
+ # {
+ # "role" : "read",
+ # "db" : "local" # <-- local DB
+ # }
+ # ]
+ # }
+
+ def make_sure_roles_are_a_list_of_dict(roles, db_name):
+ output = list()
+ for role in roles:
+ if isinstance(role, (binary_type, text_type)):
+ new_role = {"role": role, "db": db_name}
+ output.append(new_role)
+ else:
+ output.append(role)
+ return output
+
+ roles_as_list_of_dict = make_sure_roles_are_a_list_of_dict(roles, db_name)
+ uinfo_roles = uinfo.get('roles', [])
+
+ if sorted(roles_as_list_of_dict, key=itemgetter('db')) == sorted(uinfo_roles, key=itemgetter('db')):
+ return False
+ return True
+
+
+# =========================================
+# Module execution.
+#
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ login_user=dict(default=None),
+ login_password=dict(default=None, no_log=True),
+ login_host=dict(default='localhost'),
+ login_port=dict(default='27017'),
+ login_database=dict(default=None),
+ replica_set=dict(default=None),
+ database=dict(required=True, aliases=['db']),
+ name=dict(required=True, aliases=['user']),
+ password=dict(aliases=['pass'], no_log=True),
+ ssl=dict(default=False, type='bool'),
+ roles=dict(default=None, type='list', elements='raw'),
+ state=dict(default='present', choices=['absent', 'present']),
+ update_password=dict(default="always", choices=["always", "on_create"]),
+ ssl_cert_reqs=dict(default='CERT_REQUIRED', choices=['CERT_NONE', 'CERT_OPTIONAL', 'CERT_REQUIRED']),
+ ),
+ supports_check_mode=True
+ )
+
+ if not pymongo_found:
+ module.fail_json(msg=missing_required_lib('pymongo'))
+
+ login_user = module.params['login_user']
+ login_password = module.params['login_password']
+ login_host = module.params['login_host']
+ login_port = module.params['login_port']
+ login_database = module.params['login_database']
+
+ replica_set = module.params['replica_set']
+ db_name = module.params['database']
+ user = module.params['name']
+ password = module.params['password']
+ ssl = module.params['ssl']
+ roles = module.params['roles'] or []
+ state = module.params['state']
+ update_password = module.params['update_password']
+
+ try:
+ connection_params = {
+ "host": login_host,
+ "port": int(login_port),
+ }
+
+ if replica_set:
+ connection_params["replicaset"] = replica_set
+
+ if ssl:
+ connection_params["ssl"] = ssl
+ connection_params["ssl_cert_reqs"] = getattr(ssl_lib, module.params['ssl_cert_reqs'])
+
+ client = MongoClient(**connection_params)
+
+ # NOTE: this check must be done ASAP.
+ # We don't need to be authenticated for it (this ability was removed in PyMongo 3.6)
+ if LooseVersion(PyMongoVersion) <= LooseVersion('3.5'):
+ check_compatibility(module, client)
+
+ if login_user is None and login_password is None:
+ mongocnf_creds = load_mongocnf()
+ if mongocnf_creds is not False:
+ login_user = mongocnf_creds['user']
+ login_password = mongocnf_creds['password']
+ elif login_password is None or login_user is None:
+ module.fail_json(msg='when supplying login arguments, both login_user and login_password must be provided')
+
+ if login_user is not None and login_password is not None:
+ client.admin.authenticate(login_user, login_password, source=login_database)
+ elif LooseVersion(PyMongoVersion) >= LooseVersion('3.0'):
+ if db_name != "admin":
+ module.fail_json(msg='The localhost login exception only allows the first admin account to be created')
+ # else: this has to be the first admin user added
+
+ except Exception as e:
+ module.fail_json(msg='unable to connect to database: %s' % to_native(e), exception=traceback.format_exc())
+
+ if state == 'present':
+ if password is None and update_password == 'always':
+ module.fail_json(msg='password parameter required when adding a user unless update_password is set to on_create')
+
+ try:
+ if update_password != 'always':
+ uinfo = user_find(client, user, db_name)
+ if uinfo:
+ password = None
+ if not check_if_roles_changed(uinfo, roles, db_name):
+ module.exit_json(changed=False, user=user)
+
+ if module.check_mode:
+ module.exit_json(changed=True, user=user)
+
+ user_add(module, client, db_name, user, password, roles)
+ except Exception as e:
+ module.fail_json(msg='Unable to add or update user: %s' % to_native(e), exception=traceback.format_exc())
+ finally:
+ try:
+ client.close()
+ except Exception:
+ pass
+ # Here we could check for a password change if MongoDB provided a query for that: https://jira.mongodb.org/browse/SERVER-22848
+ # newuinfo = user_find(client, user, db_name)
+ # if uinfo['role'] == newuinfo['role'] and CheckPasswordHere:
+ # module.exit_json(changed=False, user=user)
+
+ elif state == 'absent':
+ try:
+ user_remove(module, client, db_name, user)
+ except Exception as e:
+ module.fail_json(msg='Unable to remove user: %s' % to_native(e), exception=traceback.format_exc())
+ finally:
+ try:
+ client.close()
+ except Exception:
+ pass
+ module.exit_json(changed=True, user=user)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/pids.py b/test/support/integration/plugins/modules/pids.py
new file mode 100644
index 00000000..4cbf45a9
--- /dev/null
+++ b/test/support/integration/plugins/modules/pids.py
@@ -0,0 +1,89 @@
+#!/usr/bin/python
+# Copyright: (c) 2019, Saranya Sridharan
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = '''
+module: pids
+version_added: 2.8
+description: "Retrieves a list of PIDs of given process name in Ansible controller/controlled machines.Returns an empty list if no process in that name exists."
+short_description: "Retrieves process IDs list if the process is running otherwise return empty list"
+author:
+ - Saranya Sridharan (@saranyasridharan)
+requirements:
+ - psutil (Python module)
+options:
+ name:
+ description: The name of the process for which to retrieve the PIDs.
+ required: true
+ type: str
+'''
+
+EXAMPLES = '''
+# Pass the process name
+- name: Getting process IDs of the process
+ pids:
+ name: python
+ register: pids_of_python
+
+- name: Printing the process IDs obtained
+ debug:
+ msg: "PIDS of python:{{pids_of_python.pids|join(',')}}"
+'''
+
+RETURN = '''
+pids:
+ description: Process IDs of the given process
+ returned: list of none, one, or more process IDs
+ type: list
+ sample: [100,200]
+'''
+
+from ansible.module_utils.basic import AnsibleModule
+try:
+ import psutil
+ HAS_PSUTIL = True
+except ImportError:
+ HAS_PSUTIL = False
+
+
+def compare_lower(a, b):
+ if a is None or b is None:
+ # this could just be "return False" but would lead to surprising behavior if both a and b are None
+ return a == b
+
+ return a.lower() == b.lower()
+
+
+def get_pid(name):
+ pids = []
+
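+ # A process matches if either its reported name or the first element of its
+ # command line equals the requested name (case-insensitively).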
+ for proc in psutil.process_iter(attrs=['name', 'cmdline']):
+ if compare_lower(proc.info['name'], name) or \
+ proc.info['cmdline'] and compare_lower(proc.info['cmdline'][0], name):
+ pids.append(proc.pid)
+
+ return pids
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ name=dict(required=True, type="str"),
+ ),
+ supports_check_mode=True,
+ )
+ if not HAS_PSUTIL:
+ module.fail_json(msg="Missing required 'psutil' python module. Try installing it with: pip install psutil")
+ name = module.params["name"]
+ response = dict(pids=get_pid(name))
+ module.exit_json(**response)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/pkgng.py b/test/support/integration/plugins/modules/pkgng.py
new file mode 100644
index 00000000..11363479
--- /dev/null
+++ b/test/support/integration/plugins/modules/pkgng.py
@@ -0,0 +1,406 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2013, bleader
+# Written by bleader <bleader@ratonland.org>
+# Based on pkgin module written by Shaun Zinck <shaun.zinck at gmail.com>
+# that was based on pacman module written by Afterburn <https://github.com/afterburn>
+# that was based on apt module written by Matthew Williams <matthew@flowroute.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: pkgng
+short_description: Package manager for FreeBSD >= 9.0
+description:
+ - Manage binary packages for FreeBSD using 'pkgng', which is available in FreeBSD versions after 9.0.
+version_added: "1.2"
+options:
+ name:
+ description:
+ - Name or list of names of packages to install/remove.
+ required: true
+ state:
+ description:
+ - State of the package.
+ - 'Note: "latest" added in 2.7'
+ choices: [ 'present', 'latest', 'absent' ]
+ required: false
+ default: present
+ cached:
+ description:
+ - Use local package base instead of fetching an updated one.
+ type: bool
+ required: false
+ default: no
+ annotation:
+ description:
+ - A comma-separated list of key/value pairs of the form
+ C(<+/-/:><key>[=<value>]). A C(+) denotes adding an annotation, a
+ C(-) denotes removing an annotation, and C(:) denotes modifying an
+ annotation.
+ If setting or modifying annotations, a value must be provided.
+ required: false
+ version_added: "1.6"
+ pkgsite:
+ description:
+ - For pkgng versions before 1.1.4, specify packagesite to use
+ for downloading packages. If not specified, use settings from
+ C(/usr/local/etc/pkg.conf).
+ - For newer pkgng versions, specify the name of a repository
+ configured in C(/usr/local/etc/pkg/repos).
+ required: false
+ rootdir:
+ description:
+ - For pkgng versions 1.5 and later, pkg will install all packages
+ within the specified root directory.
+ - Can not be used together with I(chroot) or I(jail) options.
+ required: false
+ chroot:
+ version_added: "2.1"
+ description:
+ - Pkg will chroot in the specified environment.
+ - Can not be used together with I(rootdir) or I(jail) options.
+ required: false
+ jail:
+ version_added: "2.4"
+ description:
+ - Pkg will execute in the given jail name or id.
+ - Can not be used together with I(chroot) or I(rootdir) options.
+ autoremove:
+ version_added: "2.2"
+ description:
+ - Remove automatically installed packages which are no longer needed.
+ required: false
+ type: bool
+ default: no
+author: "bleader (@bleader)"
+notes:
+ - When using pkgsite, be aware that packages already in the cache won't be downloaded again.
+ - When used with a `loop:`, each package will be processed individually;
+ it is much more efficient to pass the list directly to the `name` option.
+'''
+
+EXAMPLES = '''
+- name: Install package foo
+ pkgng:
+ name: foo
+ state: present
+
+- name: Annotate package foo and bar
+ pkgng:
+ name: foo,bar
+ annotation: '+test1=baz,-test2,:test3=foobar'
+
+- name: Remove packages foo and bar
+ pkgng:
+ name: foo,bar
+ state: absent
+
+# "latest" support added in 2.7
+- name: Upgrade package baz
+ pkgng:
+ name: baz
+ state: latest
+'''
+
+
+import re
+from ansible.module_utils.basic import AnsibleModule
+
+
+def query_package(module, pkgng_path, name, dir_arg):
+
+ rc, out, err = module.run_command("%s %s info -g -e %s" % (pkgng_path, dir_arg, name))
+
+ if rc == 0:
+ return True
+
+ return False
+
+
+def query_update(module, pkgng_path, name, dir_arg, old_pkgng, pkgsite):
+
+ # Check to see if a package upgrade is available.
+ # rc = 0, no updates available or package not installed
+ # rc = 1, updates available
+ if old_pkgng:
+ rc, out, err = module.run_command("%s %s upgrade -g -n %s" % (pkgsite, pkgng_path, name))
+ else:
+ rc, out, err = module.run_command("%s %s upgrade %s -g -n %s" % (pkgng_path, dir_arg, pkgsite, name))
+
+ if rc == 1:
+ return True
+
+ return False
+
+
+def pkgng_older_than(module, pkgng_path, compare_version):
+
+ rc, out, err = module.run_command("%s -v" % pkgng_path)
+ version = [int(x) for x in re.split(r'[\._]', out)]
+
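+ # Compare the versions element by element; the else branch runs only when
+ # the loop stops at a differing element (no break), and that element decides
+ # whether the installed pkgng is older than compare_version.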
+ i = 0
+ new_pkgng = True
+ while compare_version[i] == version[i]:
+ i += 1
+ if i == min(len(compare_version), len(version)):
+ break
+ else:
+ if compare_version[i] > version[i]:
+ new_pkgng = False
+ return not new_pkgng
+
+
+def remove_packages(module, pkgng_path, packages, dir_arg):
+
+ remove_c = 0
+    # Loop over packages individually so that, on error, we can report which package failed
+ for package in packages:
+ # Query the package first, to see if we even need to remove
+ if not query_package(module, pkgng_path, package, dir_arg):
+ continue
+
+ if not module.check_mode:
+ rc, out, err = module.run_command("%s %s delete -y %s" % (pkgng_path, dir_arg, package))
+
+ if not module.check_mode and query_package(module, pkgng_path, package, dir_arg):
+ module.fail_json(msg="failed to remove %s: %s" % (package, out))
+
+ remove_c += 1
+
+ if remove_c > 0:
+
+ return (True, "removed %s package(s)" % remove_c)
+
+ return (False, "package(s) already absent")
+
+
+def install_packages(module, pkgng_path, packages, cached, pkgsite, dir_arg, state):
+
+ install_c = 0
+
+ # as of pkg-1.1.4, PACKAGESITE is deprecated in favor of repository definitions
+ # in /usr/local/etc/pkg/repos
+ old_pkgng = pkgng_older_than(module, pkgng_path, [1, 1, 4])
+ if pkgsite != "":
+ if old_pkgng:
+ pkgsite = "PACKAGESITE=%s" % (pkgsite)
+ else:
+ pkgsite = "-r %s" % (pkgsite)
+
+ # This environment variable skips mid-install prompts,
+ # setting them to their default values.
+ batch_var = 'env BATCH=yes'
+
+ if not module.check_mode and not cached:
+ if old_pkgng:
+ rc, out, err = module.run_command("%s %s update" % (pkgsite, pkgng_path))
+ else:
+ rc, out, err = module.run_command("%s %s update" % (pkgng_path, dir_arg))
+ if rc != 0:
+ module.fail_json(msg="Could not update catalogue [%d]: %s %s" % (rc, out, err))
+
+ for package in packages:
+ already_installed = query_package(module, pkgng_path, package, dir_arg)
+ if already_installed and state == "present":
+ continue
+
+ update_available = query_update(module, pkgng_path, package, dir_arg, old_pkgng, pkgsite)
+ if not update_available and already_installed and state == "latest":
+ continue
+
+ if not module.check_mode:
+ if already_installed:
+ action = "upgrade"
+ else:
+ action = "install"
+ if old_pkgng:
+ rc, out, err = module.run_command("%s %s %s %s -g -U -y %s" % (batch_var, pkgsite, pkgng_path, action, package))
+ else:
+ rc, out, err = module.run_command("%s %s %s %s %s -g -U -y %s" % (batch_var, pkgng_path, dir_arg, action, pkgsite, package))
+
+ if not module.check_mode and not query_package(module, pkgng_path, package, dir_arg):
+ module.fail_json(msg="failed to %s %s: %s" % (action, package, out), stderr=err)
+
+ install_c += 1
+
+ if install_c > 0:
+ return (True, "added %s package(s)" % (install_c))
+
+ return (False, "package(s) already %s" % (state))
+
+
+def annotation_query(module, pkgng_path, package, tag, dir_arg):
+ rc, out, err = module.run_command("%s %s info -g -A %s" % (pkgng_path, dir_arg, package))
+ match = re.search(r'^\s*(?P<tag>%s)\s*:\s*(?P<value>\w+)' % tag, out, flags=re.MULTILINE)
+ if match:
+ return match.group('value')
+ return False
+
+
+def annotation_add(module, pkgng_path, package, tag, value, dir_arg):
+ _value = annotation_query(module, pkgng_path, package, tag, dir_arg)
+ if not _value:
+ # Annotation does not exist, add it.
+ rc, out, err = module.run_command('%s %s annotate -y -A %s %s "%s"'
+ % (pkgng_path, dir_arg, package, tag, value))
+ if rc != 0:
+ module.fail_json(msg="could not annotate %s: %s"
+ % (package, out), stderr=err)
+ return True
+ elif _value != value:
+ # Annotation exists, but value differs
+ module.fail_json(
+            msg="failed to annotate %s, because %s is already set to %s, but should be set to %s"
+ % (package, tag, _value, value))
+ return False
+ else:
+ # Annotation exists, nothing to do
+ return False
+
+
+def annotation_delete(module, pkgng_path, package, tag, value, dir_arg):
+ _value = annotation_query(module, pkgng_path, package, tag, dir_arg)
+ if _value:
+ rc, out, err = module.run_command('%s %s annotate -y -D %s %s'
+ % (pkgng_path, dir_arg, package, tag))
+ if rc != 0:
+ module.fail_json(msg="could not delete annotation to %s: %s"
+ % (package, out), stderr=err)
+ return True
+ return False
+
+
+def annotation_modify(module, pkgng_path, package, tag, value, dir_arg):
+ _value = annotation_query(module, pkgng_path, package, tag, dir_arg)
+    if not _value:
+ # No such tag
+ module.fail_json(msg="could not change annotation to %s: tag %s does not exist"
+ % (package, tag))
+ elif _value == value:
+ # No change in value
+ return False
+ else:
+ rc, out, err = module.run_command('%s %s annotate -y -M %s %s "%s"'
+ % (pkgng_path, dir_arg, package, tag, value))
+ if rc != 0:
+ module.fail_json(msg="could not change annotation annotation to %s: %s"
+ % (package, out), stderr=err)
+ return True
+
+
+def annotate_packages(module, pkgng_path, packages, annotation, dir_arg):
+ annotate_c = 0
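+    # Each comma-separated annotation parses into a dict, e.g.
+    # '+test1=baz' -> {'operation': '+', 'tag': 'test1', 'value': 'baz'}.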
+    # Build the list eagerly so it can be re-iterated for every package
+    # (map() returns a one-shot iterator on Python 3).
+    annotations = [re.match(r'(?P<operation>[+\-:])(?P<tag>\w+)(=(?P<value>\w+))?',
+                            _annotation).groupdict()
+                   for _annotation in annotation.split(',')]
+
+ operation = {
+ '+': annotation_add,
+ '-': annotation_delete,
+ ':': annotation_modify
+ }
+
+ for package in packages:
+ for _annotation in annotations:
+            if operation[_annotation['operation']](module, pkgng_path, package, _annotation['tag'], _annotation['value'], dir_arg):
+ annotate_c += 1
+
+ if annotate_c > 0:
+ return (True, "added %s annotations." % annotate_c)
+ return (False, "changed no annotations")
+
+
+def autoremove_packages(module, pkgng_path, dir_arg):
+ rc, out, err = module.run_command("%s %s autoremove -n" % (pkgng_path, dir_arg))
+
+ autoremove_c = 0
+
+ match = re.search('^Deinstallation has been requested for the following ([0-9]+) packages', out, re.MULTILINE)
+ if match:
+ autoremove_c = int(match.group(1))
+
+ if autoremove_c == 0:
+ return False, "no package(s) to autoremove"
+
+ if not module.check_mode:
+ rc, out, err = module.run_command("%s %s autoremove -y" % (pkgng_path, dir_arg))
+
+ return True, "autoremoved %d package(s)" % (autoremove_c)
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ state=dict(default="present", choices=["present", "latest", "absent"], required=False),
+ name=dict(aliases=["pkg"], required=True, type='list'),
+ cached=dict(default=False, type='bool'),
+ annotation=dict(default="", required=False),
+ pkgsite=dict(default="", required=False),
+ rootdir=dict(default="", required=False, type='path'),
+ chroot=dict(default="", required=False, type='path'),
+ jail=dict(default="", required=False, type='str'),
+ autoremove=dict(default=False, type='bool')),
+ supports_check_mode=True,
+ mutually_exclusive=[["rootdir", "chroot", "jail"]])
+
+ pkgng_path = module.get_bin_path('pkg', True)
+
+ p = module.params
+
+ pkgs = p["name"]
+
+ changed = False
+ msgs = []
+ dir_arg = ""
+
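+    # Build the shared --rootdir/--chroot/--jail argument; AnsibleModule's
+    # mutually_exclusive above guarantees at most one of the three is set.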
+ if p["rootdir"] != "":
+ old_pkgng = pkgng_older_than(module, pkgng_path, [1, 5, 0])
+ if old_pkgng:
+ module.fail_json(msg="To use option 'rootdir' pkg version must be 1.5 or greater")
+ else:
+ dir_arg = "--rootdir %s" % (p["rootdir"])
+
+ if p["chroot"] != "":
+ dir_arg = '--chroot %s' % (p["chroot"])
+
+ if p["jail"] != "":
+ dir_arg = '--jail %s' % (p["jail"])
+
+ if p["state"] in ("present", "latest"):
+ _changed, _msg = install_packages(module, pkgng_path, pkgs, p["cached"], p["pkgsite"], dir_arg, p["state"])
+ changed = changed or _changed
+ msgs.append(_msg)
+
+ elif p["state"] == "absent":
+ _changed, _msg = remove_packages(module, pkgng_path, pkgs, dir_arg)
+ changed = changed or _changed
+ msgs.append(_msg)
+
+ if p["autoremove"]:
+ _changed, _msg = autoremove_packages(module, pkgng_path, dir_arg)
+ changed = changed or _changed
+ msgs.append(_msg)
+
+ if p["annotation"]:
+ _changed, _msg = annotate_packages(module, pkgng_path, pkgs, p["annotation"], dir_arg)
+ changed = changed or _changed
+ msgs.append(_msg)
+
+ module.exit_json(changed=changed, msg=", ".join(msgs))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/postgresql_db.py b/test/support/integration/plugins/modules/postgresql_db.py
new file mode 100644
index 00000000..40858d99
--- /dev/null
+++ b/test/support/integration/plugins/modules/postgresql_db.py
@@ -0,0 +1,657 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: postgresql_db
+short_description: Add or remove PostgreSQL databases from a remote host.
+description:
+ - Add or remove PostgreSQL databases from a remote host.
+version_added: '0.6'
+options:
+ name:
+ description:
+ - Name of the database to add or remove
+ type: str
+ required: true
+ aliases: [ db ]
+ port:
+ description:
+ - Database port to connect (if needed)
+ type: int
+ default: 5432
+ aliases:
+ - login_port
+ owner:
+ description:
+ - Name of the role to set as owner of the database
+ type: str
+ template:
+ description:
+ - Template used to create the database
+ type: str
+ encoding:
+ description:
+ - Encoding of the database
+ type: str
+ lc_collate:
+ description:
+ - Collation order (LC_COLLATE) to use in the database. Must match collation order of template database unless C(template0) is used as template.
+ type: str
+ lc_ctype:
+ description:
+ - Character classification (LC_CTYPE) to use in the database (e.g. lower, upper, ...) Must match LC_CTYPE of template database unless C(template0)
+ is used as template.
+ type: str
+ session_role:
+ description:
+ - Switch to session_role after connecting. The specified session_role must be a role that the current login_user is a member of.
+ - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally.
+ type: str
+ version_added: '2.8'
+ state:
+ description:
+ - The database state.
+ - C(present) implies that the database should be created if necessary.
+ - C(absent) implies that the database should be removed if present.
+    - C(dump) requires a target definition to which the database will be backed up. (Added in Ansible 2.4)
+      Note that some versions of pg_dump (the embedded PostgreSQL utility the module uses)
+      return rc 0 even when errors occur (e.g. when the connection is forbidden by pg_hba.conf),
+      in which case the module reports changed=True although the dump was never actually performed.
+      Please make sure that your version of pg_dump returns rc 1 in this case.
+ - C(restore) also requires a target definition from which the database will be restored. (Added in Ansible 2.4)
+ - The format of the backup will be detected based on the target name.
+ - Supported compression formats for dump and restore include C(.pgc), C(.bz2), C(.gz) and C(.xz)
+ - Supported formats for dump and restore include C(.sql) and C(.tar)
+ type: str
+ choices: [ absent, dump, present, restore ]
+ default: present
+ target:
+ description:
+ - File to back up or restore from.
+ - Used when I(state) is C(dump) or C(restore).
+ type: path
+ version_added: '2.4'
+ target_opts:
+ description:
+ - Further arguments for pg_dump or pg_restore.
+ - Used when I(state) is C(dump) or C(restore).
+ type: str
+ version_added: '2.4'
+ maintenance_db:
+ description:
+      - The value specifies the initial database (also called the maintenance DB) that Ansible connects to.
+ type: str
+ default: postgres
+ version_added: '2.5'
+ conn_limit:
+ description:
+ - Specifies the database connection limit.
+ type: str
+ version_added: '2.8'
+ tablespace:
+ description:
+ - The tablespace to set for the database
+ U(https://www.postgresql.org/docs/current/sql-alterdatabase.html).
+ - If you want to move the database back to the default tablespace,
+ explicitly set this to pg_default.
+ type: path
+ version_added: '2.9'
+ dump_extra_args:
+ description:
+ - Provides additional arguments when I(state) is C(dump).
+      - Cannot be used with dump-file-format-related arguments like C(--format=d).
+ type: str
+ version_added: '2.10'
+seealso:
+- name: CREATE DATABASE reference
+ description: Complete reference of the CREATE DATABASE command documentation.
+ link: https://www.postgresql.org/docs/current/sql-createdatabase.html
+- name: DROP DATABASE reference
+ description: Complete reference of the DROP DATABASE command documentation.
+ link: https://www.postgresql.org/docs/current/sql-dropdatabase.html
+- name: pg_dump reference
+ description: Complete reference of pg_dump documentation.
+ link: https://www.postgresql.org/docs/current/app-pgdump.html
+- name: pg_restore reference
+ description: Complete reference of pg_restore documentation.
+ link: https://www.postgresql.org/docs/current/app-pgrestore.html
+- module: postgresql_tablespace
+- module: postgresql_info
+- module: postgresql_ping
+notes:
+- The C(dump) and C(restore) states don't require I(psycopg2) since version 2.8.
+author: "Ansible Core Team"
+extends_documentation_fragment:
+- postgres
+'''
+
+EXAMPLES = r'''
+- name: Create a new database with name "acme"
+ postgresql_db:
+ name: acme
+
+# Note: If a template different from "template0" is specified, encoding and locale settings must match those of the template.
+- name: Create a new database with name "acme" and specific encoding and locale # settings.
+ postgresql_db:
+ name: acme
+ encoding: UTF-8
+ lc_collate: de_DE.UTF-8
+ lc_ctype: de_DE.UTF-8
+ template: template0
+
+# Note: Default limit for the number of concurrent connections to a specific database is "-1", which means "unlimited"
+- name: Create a new database with name "acme" which has a limit of 100 concurrent connections
+ postgresql_db:
+ name: acme
+ conn_limit: "100"
+
+- name: Dump an existing database to a file
+ postgresql_db:
+ name: acme
+ state: dump
+ target: /tmp/acme.sql
+
+- name: Dump an existing database to a file excluding the test table
+ postgresql_db:
+ name: acme
+ state: dump
+ target: /tmp/acme.sql
+ dump_extra_args: --exclude-table=test
+
+- name: Dump an existing database to a file (with compression)
+ postgresql_db:
+ name: acme
+ state: dump
+ target: /tmp/acme.sql.gz
+
+- name: Dump a single schema for an existing database
+ postgresql_db:
+ name: acme
+ state: dump
+ target: /tmp/acme.sql
+ target_opts: "-n public"
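+
+# A sketch of a restore; the on-disk format (and any decompression) is
+# detected from the target file extension, as described in the docs above.
+- name: Restore database "acme" from a compressed dump
+  postgresql_db:
+    name: acme
+    state: restore
+    target: /tmp/acme.sql.gz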
+
+# Note: In the example below, if database foo exists and has another tablespace
+# the tablespace will be changed to foo. Access to the database will be locked
+# until the copying of database files is finished.
+- name: Create a new database called foo in tablespace bar
+ postgresql_db:
+ name: foo
+ tablespace: bar
+'''
+
+RETURN = r'''
+executed_commands:
+  description: List of commands the module attempted to run.
+ returned: always
+ type: list
+ sample: ["CREATE DATABASE acme"]
+ version_added: '2.10'
+'''
+
+
+import os
+import subprocess
+import traceback
+
+try:
+ import psycopg2
+ import psycopg2.extras
+except ImportError:
+ HAS_PSYCOPG2 = False
+else:
+ HAS_PSYCOPG2 = True
+
+import ansible.module_utils.postgres as pgutils
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.database import SQLParseError, pg_quote_identifier
+from ansible.module_utils.six import iteritems
+from ansible.module_utils.six.moves import shlex_quote
+from ansible.module_utils._text import to_native
+
+executed_commands = []
+
+
+class NotSupportedError(Exception):
+ pass
+
+# ===========================================
+# PostgreSQL module specific support methods.
+#
+
+
+def set_owner(cursor, db, owner):
+ query = 'ALTER DATABASE %s OWNER TO "%s"' % (
+ pg_quote_identifier(db, 'database'),
+ owner)
+ executed_commands.append(query)
+ cursor.execute(query)
+ return True
+
+
+def set_conn_limit(cursor, db, conn_limit):
+ query = "ALTER DATABASE %s CONNECTION LIMIT %s" % (
+ pg_quote_identifier(db, 'database'),
+ conn_limit)
+ executed_commands.append(query)
+ cursor.execute(query)
+ return True
+
+
+def get_encoding_id(cursor, encoding):
+ query = "SELECT pg_char_to_encoding(%(encoding)s) AS encoding_id;"
+ cursor.execute(query, {'encoding': encoding})
+ return cursor.fetchone()['encoding_id']
+
+
+def get_db_info(cursor, db):
+ query = """
+ SELECT rolname AS owner,
+ pg_encoding_to_char(encoding) AS encoding, encoding AS encoding_id,
+ datcollate AS lc_collate, datctype AS lc_ctype, pg_database.datconnlimit AS conn_limit,
+ spcname AS tablespace
+ FROM pg_database
+ JOIN pg_roles ON pg_roles.oid = pg_database.datdba
+ JOIN pg_tablespace ON pg_tablespace.oid = pg_database.dattablespace
+ WHERE datname = %(db)s
+ """
+ cursor.execute(query, {'db': db})
+ return cursor.fetchone()
+
+
+def db_exists(cursor, db):
+ query = "SELECT * FROM pg_database WHERE datname=%(db)s"
+ cursor.execute(query, {'db': db})
+ return cursor.rowcount == 1
+
+
+def db_delete(cursor, db):
+ if db_exists(cursor, db):
+ query = "DROP DATABASE %s" % pg_quote_identifier(db, 'database')
+ executed_commands.append(query)
+ cursor.execute(query)
+ return True
+ else:
+ return False
+
+
+def db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace):
+ params = dict(enc=encoding, collate=lc_collate, ctype=lc_ctype, conn_limit=conn_limit, tablespace=tablespace)
+ if not db_exists(cursor, db):
+ query_fragments = ['CREATE DATABASE %s' % pg_quote_identifier(db, 'database')]
+ if owner:
+ query_fragments.append('OWNER "%s"' % owner)
+ if template:
+ query_fragments.append('TEMPLATE %s' % pg_quote_identifier(template, 'database'))
+ if encoding:
+ query_fragments.append('ENCODING %(enc)s')
+ if lc_collate:
+ query_fragments.append('LC_COLLATE %(collate)s')
+ if lc_ctype:
+ query_fragments.append('LC_CTYPE %(ctype)s')
+ if tablespace:
+ query_fragments.append('TABLESPACE %s' % pg_quote_identifier(tablespace, 'tablespace'))
+ if conn_limit:
+ query_fragments.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit})
+ query = ' '.join(query_fragments)
+ executed_commands.append(cursor.mogrify(query, params))
+ cursor.execute(query, params)
+ return True
+ else:
+ db_info = get_db_info(cursor, db)
+ if (encoding and get_encoding_id(cursor, encoding) != db_info['encoding_id']):
+ raise NotSupportedError(
+ 'Changing database encoding is not supported. '
+ 'Current encoding: %s' % db_info['encoding']
+ )
+ elif lc_collate and lc_collate != db_info['lc_collate']:
+ raise NotSupportedError(
+ 'Changing LC_COLLATE is not supported. '
+ 'Current LC_COLLATE: %s' % db_info['lc_collate']
+ )
+ elif lc_ctype and lc_ctype != db_info['lc_ctype']:
+ raise NotSupportedError(
+ 'Changing LC_CTYPE is not supported.'
+ 'Current LC_CTYPE: %s' % db_info['lc_ctype']
+ )
+ else:
+ changed = False
+
+ if owner and owner != db_info['owner']:
+ changed = set_owner(cursor, db, owner)
+
+ if conn_limit and conn_limit != str(db_info['conn_limit']):
+ changed = set_conn_limit(cursor, db, conn_limit)
+
+ if tablespace and tablespace != db_info['tablespace']:
+ changed = set_tablespace(cursor, db, tablespace)
+
+ return changed
+
+
+def db_matches(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace):
+ if not db_exists(cursor, db):
+ return False
+ else:
+ db_info = get_db_info(cursor, db)
+ if (encoding and get_encoding_id(cursor, encoding) != db_info['encoding_id']):
+ return False
+ elif lc_collate and lc_collate != db_info['lc_collate']:
+ return False
+ elif lc_ctype and lc_ctype != db_info['lc_ctype']:
+ return False
+ elif owner and owner != db_info['owner']:
+ return False
+ elif conn_limit and conn_limit != str(db_info['conn_limit']):
+ return False
+ elif tablespace and tablespace != db_info['tablespace']:
+ return False
+ else:
+ return True
+
+
+def db_dump(module, target, target_opts="",
+ db=None,
+ dump_extra_args=None,
+ user=None,
+ password=None,
+ host=None,
+ port=None,
+ **kw):
+
+ flags = login_flags(db, host, port, user, db_prefix=False)
+ cmd = module.get_bin_path('pg_dump', True)
+ comp_prog_path = None
+
+ if os.path.splitext(target)[-1] == '.tar':
+ flags.append(' --format=t')
+ elif os.path.splitext(target)[-1] == '.pgc':
+ flags.append(' --format=c')
+ if os.path.splitext(target)[-1] == '.gz':
+ if module.get_bin_path('pigz'):
+ comp_prog_path = module.get_bin_path('pigz', True)
+ else:
+ comp_prog_path = module.get_bin_path('gzip', True)
+ elif os.path.splitext(target)[-1] == '.bz2':
+ comp_prog_path = module.get_bin_path('bzip2', True)
+ elif os.path.splitext(target)[-1] == '.xz':
+ comp_prog_path = module.get_bin_path('xz', True)
+
+ cmd += "".join(flags)
+
+ if dump_extra_args:
+ cmd += " {0} ".format(dump_extra_args)
+
+ if target_opts:
+ cmd += " {0} ".format(target_opts)
+
+ if comp_prog_path:
+ # Use a fifo to be notified of an error in pg_dump
+ # Using shell pipe has no way to return the code of the first command
+ # in a portable way.
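+        # The assembled shell command has the shape (illustrative paths):
+        #   gzip <FIFO > /tmp/acme.sql.gz & pg_dump ... >FIFO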
+ fifo = os.path.join(module.tmpdir, 'pg_fifo')
+ os.mkfifo(fifo)
+ cmd = '{1} <{3} > {2} & {0} >{3}'.format(cmd, comp_prog_path, shlex_quote(target), fifo)
+ else:
+ cmd = '{0} > {1}'.format(cmd, shlex_quote(target))
+
+ return do_with_password(module, cmd, password)
+
+
+def db_restore(module, target, target_opts="",
+ db=None,
+ user=None,
+ password=None,
+ host=None,
+ port=None,
+ **kw):
+
+ flags = login_flags(db, host, port, user)
+ comp_prog_path = None
+ cmd = module.get_bin_path('psql', True)
+
+ if os.path.splitext(target)[-1] == '.sql':
+ flags.append(' --file={0}'.format(target))
+
+ elif os.path.splitext(target)[-1] == '.tar':
+ flags.append(' --format=Tar')
+ cmd = module.get_bin_path('pg_restore', True)
+
+ elif os.path.splitext(target)[-1] == '.pgc':
+ flags.append(' --format=Custom')
+ cmd = module.get_bin_path('pg_restore', True)
+
+ elif os.path.splitext(target)[-1] == '.gz':
+ comp_prog_path = module.get_bin_path('zcat', True)
+
+ elif os.path.splitext(target)[-1] == '.bz2':
+ comp_prog_path = module.get_bin_path('bzcat', True)
+
+ elif os.path.splitext(target)[-1] == '.xz':
+ comp_prog_path = module.get_bin_path('xzcat', True)
+
+ cmd += "".join(flags)
+ if target_opts:
+ cmd += " {0} ".format(target_opts)
+
+ if comp_prog_path:
+        env = os.environ.copy()
+        if password:
+            env["PGPASSWORD"] = password
+ p1 = subprocess.Popen([comp_prog_path, target], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ p2 = subprocess.Popen(cmd, stdin=p1.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=env)
+ (stdout2, stderr2) = p2.communicate()
+ p1.stdout.close()
+ p1.wait()
+ if p1.returncode != 0:
+ stderr1 = p1.stderr.read()
+ return p1.returncode, '', stderr1, 'cmd: ****'
+ else:
+ return p2.returncode, '', stderr2, 'cmd: ****'
+ else:
+ cmd = '{0} < {1}'.format(cmd, shlex_quote(target))
+
+ return do_with_password(module, cmd, password)
+
+
+def login_flags(db, host, port, user, db_prefix=True):
+ """
+ returns a list of connection argument strings each prefixed
+ with a space and quoted where necessary to later be combined
+ in a single shell string with `"".join(rv)`
+
+ db_prefix determines if "--dbname" is prefixed to the db argument,
+ since the argument was introduced in 9.3.
+ """
+ flags = []
+ if db:
+ if db_prefix:
+ flags.append(' --dbname={0}'.format(shlex_quote(db)))
+ else:
+ flags.append(' {0}'.format(shlex_quote(db)))
+ if host:
+ flags.append(' --host={0}'.format(host))
+ if port:
+ flags.append(' --port={0}'.format(port))
+ if user:
+ flags.append(' --username={0}'.format(user))
+ return flags
+
+
+def do_with_password(module, cmd, password):
+ env = {}
+ if password:
+ env = {"PGPASSWORD": password}
+ executed_commands.append(cmd)
+    # run_command() returns (rc, stdout, stderr); keep that order in the return value.
+    rc, stdout, stderr = module.run_command(cmd, use_unsafe_shell=True, environ_update=env)
+    return rc, stdout, stderr, cmd
+
+
+def set_tablespace(cursor, db, tablespace):
+ query = "ALTER DATABASE %s SET TABLESPACE %s" % (
+ pg_quote_identifier(db, 'database'),
+ pg_quote_identifier(tablespace, 'tablespace'))
+ executed_commands.append(query)
+ cursor.execute(query)
+ return True
+
+# ===========================================
+# Module execution.
+#
+
+
+def main():
+ argument_spec = pgutils.postgres_common_argument_spec()
+ argument_spec.update(
+ db=dict(type='str', required=True, aliases=['name']),
+ owner=dict(type='str', default=''),
+ template=dict(type='str', default=''),
+ encoding=dict(type='str', default=''),
+ lc_collate=dict(type='str', default=''),
+ lc_ctype=dict(type='str', default=''),
+ state=dict(type='str', default='present', choices=['absent', 'dump', 'present', 'restore']),
+ target=dict(type='path', default=''),
+ target_opts=dict(type='str', default=''),
+ maintenance_db=dict(type='str', default="postgres"),
+ session_role=dict(type='str'),
+ conn_limit=dict(type='str', default=''),
+ tablespace=dict(type='path', default=''),
+ dump_extra_args=dict(type='str', default=None),
+ )
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True
+ )
+
+ db = module.params["db"]
+ owner = module.params["owner"]
+ template = module.params["template"]
+ encoding = module.params["encoding"]
+ lc_collate = module.params["lc_collate"]
+ lc_ctype = module.params["lc_ctype"]
+ target = module.params["target"]
+ target_opts = module.params["target_opts"]
+ state = module.params["state"]
+ changed = False
+ maintenance_db = module.params['maintenance_db']
+ session_role = module.params["session_role"]
+ conn_limit = module.params['conn_limit']
+ tablespace = module.params['tablespace']
+ dump_extra_args = module.params['dump_extra_args']
+
+ raw_connection = state in ("dump", "restore")
+
+ if not raw_connection:
+ pgutils.ensure_required_libs(module)
+
+    # To use default values, keyword arguments must be absent, so
+ # check which values are empty and don't include in the **kw
+ # dictionary
+ params_map = {
+ "login_host": "host",
+ "login_user": "user",
+ "login_password": "password",
+ "port": "port",
+ "ssl_mode": "sslmode",
+ "ca_cert": "sslrootcert"
+ }
+ kw = dict((params_map[k], v) for (k, v) in iteritems(module.params)
+ if k in params_map and v != '' and v is not None)
+
+ # If a login_unix_socket is specified, incorporate it here.
+ is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost"
+
+ if is_localhost and module.params["login_unix_socket"] != "":
+ kw["host"] = module.params["login_unix_socket"]
+
+ if target == "":
+ target = "{0}/{1}.sql".format(os.getcwd(), db)
+ target = os.path.expanduser(target)
+
+ if not raw_connection:
+ try:
+ db_connection = psycopg2.connect(database=maintenance_db, **kw)
+
+ # Enable autocommit so we can create databases
+ if psycopg2.__version__ >= '2.4.2':
+ db_connection.autocommit = True
+ else:
+ db_connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+ cursor = db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
+
+ except TypeError as e:
+ if 'sslrootcert' in e.args[0]:
+ module.fail_json(msg='Postgresql server must be at least version 8.4 to support sslrootcert. Exception: {0}'.format(to_native(e)),
+ exception=traceback.format_exc())
+ module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc())
+
+ except Exception as e:
+ module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc())
+
+ if session_role:
+ try:
+ cursor.execute('SET ROLE "%s"' % session_role)
+ except Exception as e:
+ module.fail_json(msg="Could not switch role: %s" % to_native(e), exception=traceback.format_exc())
+
+ try:
+ if module.check_mode:
+ if state == "absent":
+ changed = db_exists(cursor, db)
+ elif state == "present":
+ changed = not db_matches(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace)
+ module.exit_json(changed=changed, db=db, executed_commands=executed_commands)
+
+ if state == "absent":
+ try:
+ changed = db_delete(cursor, db)
+ except SQLParseError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+
+ elif state == "present":
+ try:
+ changed = db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace)
+ except SQLParseError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+
+ elif state in ("dump", "restore"):
+        method = db_dump if state == "dump" else db_restore
+ try:
+ if state == 'dump':
+ rc, stdout, stderr, cmd = method(module, target, target_opts, db, dump_extra_args, **kw)
+ else:
+ rc, stdout, stderr, cmd = method(module, target, target_opts, db, **kw)
+
+ if rc != 0:
+ module.fail_json(msg=stderr, stdout=stdout, rc=rc, cmd=cmd)
+ else:
+ module.exit_json(changed=True, msg=stdout, stderr=stderr, rc=rc, cmd=cmd,
+ executed_commands=executed_commands)
+ except SQLParseError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+
+ except NotSupportedError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+ except SystemExit:
+ # Avoid catching this on Python 2.4
+ raise
+ except Exception as e:
+ module.fail_json(msg="Database query failed: %s" % to_native(e), exception=traceback.format_exc())
+
+ module.exit_json(changed=changed, db=db, executed_commands=executed_commands)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/postgresql_privs.py b/test/support/integration/plugins/modules/postgresql_privs.py
new file mode 100644
index 00000000..ba8324dd
--- /dev/null
+++ b/test/support/integration/plugins/modules/postgresql_privs.py
@@ -0,0 +1,1097 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: Ansible Project
+# Copyright: (c) 2019, Tobias Birkefeld (@tcraxs) <t@craxs.de>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: postgresql_privs
+version_added: '1.2'
+short_description: Grant or revoke privileges on PostgreSQL database objects
+description:
+- Grant or revoke privileges on PostgreSQL database objects.
+- This module is basically a wrapper around most of the functionality of
+ PostgreSQL's GRANT and REVOKE statements with detection of changes
+ (GRANT/REVOKE I(privs) ON I(type) I(objs) TO/FROM I(roles)).
+options:
+ database:
+ description:
+ - Name of database to connect to.
+ required: yes
+ type: str
+ aliases:
+ - db
+ - login_db
+ state:
+ description:
+ - If C(present), the specified privileges are granted, if C(absent) they are revoked.
+ type: str
+ default: present
+ choices: [ absent, present ]
+ privs:
+ description:
+ - Comma separated list of privileges to grant/revoke.
+ type: str
+ aliases:
+ - priv
+ type:
+ description:
+ - Type of database object to set privileges on.
+ - The C(default_privs) choice is available starting at version 2.7.
+ - The C(foreign_data_wrapper) and C(foreign_server) object types are available from Ansible version '2.8'.
+ - The C(type) choice is available from Ansible version '2.10'.
+ type: str
+ default: table
+ choices: [ database, default_privs, foreign_data_wrapper, foreign_server, function,
+ group, language, table, tablespace, schema, sequence, type ]
+ objs:
+ description:
+ - Comma separated list of database objects to set privileges on.
+ - If I(type) is C(table), C(partition table), C(sequence) or C(function),
+      the special value C(ALL_IN_SCHEMA) can be provided instead to specify all
+ database objects of type I(type) in the schema specified via I(schema).
+ (This also works with PostgreSQL < 9.0.) (C(ALL_IN_SCHEMA) is available
+ for C(function) and C(partition table) from version 2.8)
+ - If I(type) is C(database), this parameter can be omitted, in which case
+ privileges are set for the database specified via I(database).
+ - 'If I(type) is I(function), colons (":") in object names will be
+ replaced with commas (needed to specify function signatures, see examples)'
+ type: str
+ aliases:
+ - obj
+ schema:
+ description:
+ - Schema that contains the database objects specified via I(objs).
+ - May only be provided if I(type) is C(table), C(sequence), C(function), C(type),
+ or C(default_privs). Defaults to C(public) in these cases.
+    - Note that for built-in types, when I(type=type), I(schema) can be
+      C(pg_catalog) or C(information_schema).
+ type: str
+ roles:
+ description:
+ - Comma separated list of role (user/group) names to set permissions for.
+ - The special value C(PUBLIC) can be provided instead to set permissions
+ for the implicitly defined PUBLIC group.
+ type: str
+ required: yes
+ aliases:
+ - role
+ fail_on_role:
+ version_added: '2.8'
+ description:
+ - If C(yes), fail when target role (for whom privs need to be granted) does not exist.
+ Otherwise just warn and continue.
+ default: yes
+ type: bool
+ session_role:
+ version_added: '2.8'
+ description:
+ - Switch to session_role after connecting.
+ - The specified session_role must be a role that the current login_user is a member of.
+ - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally.
+ type: str
+ target_roles:
+ description:
+ - A list of existing role (user/group) names to set as the
+ default permissions for database objects subsequently created by them.
+ - Parameter I(target_roles) is only available with C(type=default_privs).
+ type: str
+ version_added: '2.8'
+ grant_option:
+ description:
+ - Whether C(role) may grant/revoke the specified privileges/group memberships to others.
+ - Set to C(no) to revoke GRANT OPTION, leave unspecified to make no changes.
+ - I(grant_option) only has an effect if I(state) is C(present).
+ type: bool
+ aliases:
+ - admin_option
+ host:
+ description:
+ - Database host address. If unspecified, connect via Unix socket.
+ type: str
+ aliases:
+ - login_host
+ port:
+ description:
+ - Database port to connect to.
+ type: int
+ default: 5432
+ aliases:
+ - login_port
+ unix_socket:
+ description:
+ - Path to a Unix domain socket for local connections.
+ type: str
+ aliases:
+ - login_unix_socket
+ login:
+ description:
+ - The username to authenticate with.
+ type: str
+ default: postgres
+ aliases:
+ - login_user
+ password:
+ description:
+ - The password to authenticate with.
+ type: str
+ aliases:
+ - login_password
+ ssl_mode:
+ description:
+ - Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated with the server.
+ - See https://www.postgresql.org/docs/current/static/libpq-ssl.html for more information on the modes.
+ - Default of C(prefer) matches libpq default.
+ type: str
+ default: prefer
+ choices: [ allow, disable, prefer, require, verify-ca, verify-full ]
+ version_added: '2.3'
+ ca_cert:
+ description:
+ - Specifies the name of a file containing SSL certificate authority (CA) certificate(s).
+ - If the file exists, the server's certificate will be verified to be signed by one of these authorities.
+ version_added: '2.3'
+ type: str
+ aliases:
+ - ssl_rootcert
+
+notes:
+- Parameters that accept comma separated lists (I(privs), I(objs), I(roles))
+ have singular alias names (I(priv), I(obj), I(role)).
+- To revoke only C(GRANT OPTION) for a specific object, set I(state) to
+ C(present) and I(grant_option) to C(no) (see examples).
+- Note that when revoking privileges from a role R, this role may still have
+ access via privileges granted to any role R is a member of including C(PUBLIC).
+- Note that when revoking privileges from a role R, you do so as the user
+ specified via I(login). If R has been granted the same privileges by
+ another user also, R can still access database objects via these privileges.
+- When revoking privileges, C(RESTRICT) is assumed (see PostgreSQL docs).
+
+seealso:
+- module: postgresql_user
+- module: postgresql_owner
+- module: postgresql_membership
+- name: PostgreSQL privileges
+ description: General information about PostgreSQL privileges.
+ link: https://www.postgresql.org/docs/current/ddl-priv.html
+- name: PostgreSQL GRANT command reference
+ description: Complete reference of the PostgreSQL GRANT command documentation.
+ link: https://www.postgresql.org/docs/current/sql-grant.html
+- name: PostgreSQL REVOKE command reference
+ description: Complete reference of the PostgreSQL REVOKE command documentation.
+ link: https://www.postgresql.org/docs/current/sql-revoke.html
+
+extends_documentation_fragment:
+- postgres
+
+author:
+- Bernhard Weitzhofer (@b6d)
+- Tobias Birkefeld (@tcraxs)
+'''
+
+EXAMPLES = r'''
+# On database "library":
+# GRANT SELECT, INSERT, UPDATE ON TABLE public.books, public.authors
+# TO librarian, reader WITH GRANT OPTION
+- name: Grant privs to librarian and reader on database library
+ postgresql_privs:
+ database: library
+ state: present
+ privs: SELECT,INSERT,UPDATE
+ type: table
+ objs: books,authors
+ schema: public
+ roles: librarian,reader
+ grant_option: yes
+
+- name: Same as above leveraging default values
+ postgresql_privs:
+ db: library
+ privs: SELECT,INSERT,UPDATE
+ objs: books,authors
+ roles: librarian,reader
+ grant_option: yes
+
+# REVOKE GRANT OPTION FOR INSERT ON TABLE books FROM reader
+# Note that role "reader" will be *granted* INSERT privilege itself if this
+# isn't already the case (since state: present).
+- name: Revoke privs from reader
+ postgresql_privs:
+ db: library
+ state: present
+ priv: INSERT
+ obj: books
+ role: reader
+ grant_option: no
+
+# "public" is the default schema. This also works for PostgreSQL 8.x.
+- name: REVOKE INSERT, UPDATE ON ALL TABLES IN SCHEMA public FROM reader
+ postgresql_privs:
+ db: library
+ state: absent
+ privs: INSERT,UPDATE
+ objs: ALL_IN_SCHEMA
+ role: reader
+
+- name: GRANT ALL PRIVILEGES ON SCHEMA public, math TO librarian
+ postgresql_privs:
+ db: library
+ privs: ALL
+ type: schema
+ objs: public,math
+ role: librarian
+
+# Note the separation of arguments with colons.
+- name: GRANT ALL PRIVILEGES ON FUNCTION math.add(int, int) TO librarian, reader
+ postgresql_privs:
+ db: library
+ privs: ALL
+ type: function
+ obj: add(int:int)
+ schema: math
+ roles: librarian,reader
+
+# Note that group role memberships apply cluster-wide and therefore are not
+# restricted to database "library" here.
+- name: GRANT librarian, reader TO alice, bob WITH ADMIN OPTION
+ postgresql_privs:
+ db: library
+ type: group
+ objs: librarian,reader
+ roles: alice,bob
+ admin_option: yes
+
+# Note that here "db: postgres" specifies the database to connect to, not the
+# database to grant privileges on (which is specified via the "objs" param)
+- name: GRANT ALL PRIVILEGES ON DATABASE library TO librarian
+ postgresql_privs:
+ db: postgres
+ privs: ALL
+ type: database
+ obj: library
+ role: librarian
+
+# If objs is omitted for type "database", it defaults to the database
+# to which the connection is established
+- name: GRANT ALL PRIVILEGES ON DATABASE library TO librarian
+ postgresql_privs:
+ db: library
+ privs: ALL
+ type: database
+ role: librarian
+
+# Available since version 2.7
+# Objs must be set; ALL_DEFAULT covers TABLES/SEQUENCES/TYPES/FUNCTIONS
+# and works only with privs=ALL. For specific object classes, set objs
+# explicitly (see the two-step example further below).
+- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO librarian
+ postgresql_privs:
+ db: library
+ objs: ALL_DEFAULT
+ privs: ALL
+ type: default_privs
+ role: librarian
+ grant_option: yes
+
+# Available since version 2.7
+# Objs must be set; ALL_DEFAULT covers TABLES/SEQUENCES/TYPES/FUNCTIONS
+# and works only with privs=ALL. For specific object classes, list them
+# explicitly as in the next two tasks.
+- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO reader, step 1
+ postgresql_privs:
+ db: library
+ objs: TABLES,SEQUENCES
+ privs: SELECT
+ type: default_privs
+ role: reader
+
+- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO reader, step 2
+ postgresql_privs:
+ db: library
+ objs: TYPES
+ privs: USAGE
+ type: default_privs
+ role: reader
+
+# Available since version 2.8
+- name: GRANT ALL PRIVILEGES ON FOREIGN DATA WRAPPER fdw TO reader
+ postgresql_privs:
+ db: test
+ objs: fdw
+ privs: ALL
+ type: foreign_data_wrapper
+ role: reader
+
+# Available since version 2.10
+- name: GRANT ALL PRIVILEGES ON TYPE customtype TO reader
+ postgresql_privs:
+ db: test
+ objs: customtype
+ privs: ALL
+ type: type
+ role: reader
+
+# Available since version 2.8
+- name: GRANT ALL PRIVILEGES ON FOREIGN SERVER fdw_server TO reader
+ postgresql_privs:
+ db: test
+ objs: fdw_server
+ privs: ALL
+ type: foreign_server
+ role: reader
+
+# Available since version 2.8
+# Grant 'execute' permissions on all functions in schema 'common' to role 'caller'
+- name: GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA common TO caller
+ postgresql_privs:
+ type: function
+ state: present
+ privs: EXECUTE
+ roles: caller
+ objs: ALL_IN_SCHEMA
+ schema: common
+
+# Available since version 2.8
+# ALTER DEFAULT PRIVILEGES FOR ROLE librarian IN SCHEMA library GRANT SELECT ON TABLES TO reader
+# Grant the role reader SELECT by default on new TABLES objects
+# created by librarian.
+- name: ALTER privs
+ postgresql_privs:
+ db: library
+ schema: library
+ objs: TABLES
+ privs: SELECT
+ type: default_privs
+ role: reader
+ target_roles: librarian
+
+# Available since version 2.8
+# ALTER DEFAULT PRIVILEGES FOR ROLE librarian IN SCHEMA library REVOKE SELECT ON TABLES FROM reader
+# Remove the default SELECT privilege of the role reader on new TABLES
+# objects created by librarian.
+- name: ALTER privs
+ postgresql_privs:
+ db: library
+ state: absent
+ schema: library
+ objs: TABLES
+ privs: SELECT
+ type: default_privs
+ role: reader
+ target_roles: librarian
+
+# Available since version 2.10
+- name: Grant type privileges for pg_catalog.numeric type to alice
+ postgresql_privs:
+ type: type
+ roles: alice
+ privs: ALL
+ objs: numeric
+ schema: pg_catalog
+ db: acme
+'''
+
+RETURN = r'''
+queries:
+ description: List of executed queries.
+ returned: always
+ type: list
+ sample: ['REVOKE GRANT OPTION FOR INSERT ON TABLE "books" FROM "reader";']
+ version_added: '2.8'
+'''
+
+import traceback
+
+PSYCOPG2_IMP_ERR = None
+try:
+ import psycopg2
+ import psycopg2.extensions
+except ImportError:
+ PSYCOPG2_IMP_ERR = traceback.format_exc()
+ psycopg2 = None
+
+# import module snippets
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils.database import pg_quote_identifier
+from ansible.module_utils.postgres import postgres_common_argument_spec
+from ansible.module_utils._text import to_native
+
+VALID_PRIVS = frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE',
+ 'REFERENCES', 'TRIGGER', 'CREATE', 'CONNECT',
+                         'TEMPORARY', 'TEMP', 'EXECUTE', 'USAGE', 'ALL'))
+VALID_DEFAULT_OBJS = {'TABLES': ('ALL', 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER'),
+ 'SEQUENCES': ('ALL', 'SELECT', 'UPDATE', 'USAGE'),
+ 'FUNCTIONS': ('ALL', 'EXECUTE'),
+ 'TYPES': ('ALL', 'USAGE')}
+
+executed_queries = []
+
+
+class Error(Exception):
+ pass
+
+
+def role_exists(module, cursor, rolname):
+ """Check user exists or not"""
+ query = "SELECT 1 FROM pg_roles WHERE rolname = '%s'" % rolname
+ try:
+ cursor.execute(query)
+ return cursor.rowcount > 0
+
+ except Exception as e:
+ module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e)))
+
+ return False
+
+
+# We don't have functools.partial in Python < 2.5
+def partial(f, *args, **kwargs):
+ """Partial function application"""
+
+ def g(*g_args, **g_kwargs):
+ new_kwargs = kwargs.copy()
+ new_kwargs.update(g_kwargs)
+        return f(*(args + g_args), **new_kwargs)
+
+ g.f = f
+ g.args = args
+ g.kwargs = kwargs
+ return g
+
+
+class Connection(object):
+ """Wrapper around a psycopg2 connection with some convenience methods"""
+
+ def __init__(self, params, module):
+ self.database = params.database
+ self.module = module
+ # To use defaults values, keyword arguments must be absent, so
+ # check which values are empty and don't include in the **kw
+ # dictionary
+ params_map = {
+ "host": "host",
+ "login": "user",
+ "password": "password",
+ "port": "port",
+ "database": "database",
+ "ssl_mode": "sslmode",
+ "ca_cert": "sslrootcert"
+ }
+
+ kw = dict((params_map[k], getattr(params, k)) for k in params_map
+ if getattr(params, k) != '' and getattr(params, k) is not None)
+
+ # If a unix_socket is specified, incorporate it here.
+ is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost"
+ if is_localhost and params.unix_socket != "":
+ kw["host"] = params.unix_socket
+
+ sslrootcert = params.ca_cert
+ if psycopg2.__version__ < '2.4.3' and sslrootcert is not None:
+            raise ValueError('psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter')
+
+ self.connection = psycopg2.connect(**kw)
+ self.cursor = self.connection.cursor()
+
+ def commit(self):
+ self.connection.commit()
+
+ def rollback(self):
+ self.connection.rollback()
+
+ @property
+ def encoding(self):
+ """Connection encoding in Python-compatible form"""
+ return psycopg2.extensions.encodings[self.connection.encoding]
+
+ # Methods for querying database objects
+
+ # PostgreSQL < 9.0 doesn't support "ALL TABLES IN SCHEMA schema"-like
+ # phrases in GRANT or REVOKE statements, therefore alternative methods are
+ # provided here.
+
+ def schema_exists(self, schema):
+ query = """SELECT count(*)
+ FROM pg_catalog.pg_namespace WHERE nspname = %s"""
+ self.cursor.execute(query, (schema,))
+ return self.cursor.fetchone()[0] > 0
+
+ def get_all_tables_in_schema(self, schema):
+ if not self.schema_exists(schema):
+ raise Error('Schema "%s" does not exist.' % schema)
+ query = """SELECT relname
+ FROM pg_catalog.pg_class c
+ JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE nspname = %s AND relkind in ('r', 'v', 'm', 'p')"""
+ self.cursor.execute(query, (schema,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_all_sequences_in_schema(self, schema):
+ if not self.schema_exists(schema):
+ raise Error('Schema "%s" does not exist.' % schema)
+ query = """SELECT relname
+ FROM pg_catalog.pg_class c
+ JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE nspname = %s AND relkind = 'S'"""
+ self.cursor.execute(query, (schema,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_all_functions_in_schema(self, schema):
+ if not self.schema_exists(schema):
+ raise Error('Schema "%s" does not exist.' % schema)
+ query = """SELECT p.proname, oidvectortypes(p.proargtypes)
+ FROM pg_catalog.pg_proc p
+ JOIN pg_namespace n ON n.oid = p.pronamespace
+ WHERE nspname = %s"""
+ self.cursor.execute(query, (schema,))
+ return ["%s(%s)" % (t[0], t[1]) for t in self.cursor.fetchall()]
+
+ # Methods for getting access control lists and group membership info
+
+ # To determine whether anything has changed after granting/revoking
+ # privileges, we compare the access control lists of the specified database
+ # objects before and afterwards. Python's list/string comparison should
+ # suffice for change detection, we should not actually have to parse ACLs.
+ # The same should apply to group membership information.
+
+ def get_table_acls(self, schema, tables):
+ query = """SELECT relacl
+ FROM pg_catalog.pg_class c
+ JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE nspname = %s AND relkind in ('r','p','v','m') AND relname = ANY (%s)
+ ORDER BY relname"""
+ self.cursor.execute(query, (schema, tables))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_sequence_acls(self, schema, sequences):
+ query = """SELECT relacl
+ FROM pg_catalog.pg_class c
+ JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE nspname = %s AND relkind = 'S' AND relname = ANY (%s)
+ ORDER BY relname"""
+ self.cursor.execute(query, (schema, sequences))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_function_acls(self, schema, function_signatures):
+ funcnames = [f.split('(', 1)[0] for f in function_signatures]
+ query = """SELECT proacl
+ FROM pg_catalog.pg_proc p
+ JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
+ WHERE nspname = %s AND proname = ANY (%s)
+ ORDER BY proname, proargtypes"""
+ self.cursor.execute(query, (schema, funcnames))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_schema_acls(self, schemas):
+ query = """SELECT nspacl FROM pg_catalog.pg_namespace
+ WHERE nspname = ANY (%s) ORDER BY nspname"""
+ self.cursor.execute(query, (schemas,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_language_acls(self, languages):
+ query = """SELECT lanacl FROM pg_catalog.pg_language
+ WHERE lanname = ANY (%s) ORDER BY lanname"""
+ self.cursor.execute(query, (languages,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_tablespace_acls(self, tablespaces):
+ query = """SELECT spcacl FROM pg_catalog.pg_tablespace
+ WHERE spcname = ANY (%s) ORDER BY spcname"""
+ self.cursor.execute(query, (tablespaces,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_database_acls(self, databases):
+ query = """SELECT datacl FROM pg_catalog.pg_database
+ WHERE datname = ANY (%s) ORDER BY datname"""
+ self.cursor.execute(query, (databases,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_group_memberships(self, groups):
+ query = """SELECT roleid, grantor, member, admin_option
+ FROM pg_catalog.pg_auth_members am
+ JOIN pg_catalog.pg_roles r ON r.oid = am.roleid
+ WHERE r.rolname = ANY(%s)
+ ORDER BY roleid, grantor, member"""
+ self.cursor.execute(query, (groups,))
+ return self.cursor.fetchall()
+
+ def get_default_privs(self, schema, *args):
+ query = """SELECT defaclacl
+ FROM pg_default_acl a
+ JOIN pg_namespace b ON a.defaclnamespace=b.oid
+ WHERE b.nspname = %s;"""
+ self.cursor.execute(query, (schema,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_foreign_data_wrapper_acls(self, fdws):
+ query = """SELECT fdwacl FROM pg_catalog.pg_foreign_data_wrapper
+ WHERE fdwname = ANY (%s) ORDER BY fdwname"""
+ self.cursor.execute(query, (fdws,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_foreign_server_acls(self, fs):
+ query = """SELECT srvacl FROM pg_catalog.pg_foreign_server
+ WHERE srvname = ANY (%s) ORDER BY srvname"""
+ self.cursor.execute(query, (fs,))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ def get_type_acls(self, schema, types):
+ query = """SELECT t.typacl FROM pg_catalog.pg_type t
+ JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+ WHERE n.nspname = %s AND t.typname = ANY (%s) ORDER BY typname"""
+ self.cursor.execute(query, (schema, types))
+ return [t[0] for t in self.cursor.fetchall()]
+
+ # Manipulating privileges
+
+ def manipulate_privs(self, obj_type, privs, objs, roles, target_roles,
+ state, grant_option, schema_qualifier=None, fail_on_role=True):
+ """Manipulate database object privileges.
+
+ :param obj_type: Type of database object to grant/revoke
+ privileges for.
+ :param privs: Either a list of privileges to grant/revoke
+ or None if type is "group".
+ :param objs: List of database objects to grant/revoke
+ privileges for.
+ :param roles: Either a list of role names or "PUBLIC"
+ for the implicitly defined "PUBLIC" group
+ :param target_roles: List of role names to grant/revoke
+ default privileges as.
+ :param state: "present" to grant privileges, "absent" to revoke.
+ :param grant_option: Only for state "present": If True, set
+ grant/admin option. If False, revoke it.
+ If None, don't change grant option.
+ :param schema_qualifier: Some object types ("TABLE", "SEQUENCE",
+ "FUNCTION") must be qualified by schema.
+ Ignored for other Types.
+ """
+ # get_status: function to get current status
+ if obj_type == 'table':
+ get_status = partial(self.get_table_acls, schema_qualifier)
+ elif obj_type == 'sequence':
+ get_status = partial(self.get_sequence_acls, schema_qualifier)
+ elif obj_type == 'function':
+ get_status = partial(self.get_function_acls, schema_qualifier)
+ elif obj_type == 'schema':
+ get_status = self.get_schema_acls
+ elif obj_type == 'language':
+ get_status = self.get_language_acls
+ elif obj_type == 'tablespace':
+ get_status = self.get_tablespace_acls
+ elif obj_type == 'database':
+ get_status = self.get_database_acls
+ elif obj_type == 'group':
+ get_status = self.get_group_memberships
+ elif obj_type == 'default_privs':
+ get_status = partial(self.get_default_privs, schema_qualifier)
+ elif obj_type == 'foreign_data_wrapper':
+ get_status = self.get_foreign_data_wrapper_acls
+ elif obj_type == 'foreign_server':
+ get_status = self.get_foreign_server_acls
+ elif obj_type == 'type':
+ get_status = partial(self.get_type_acls, schema_qualifier)
+ else:
+ raise Error('Unsupported database object type "%s".' % obj_type)
+
+ # Return False (nothing has changed) if there are no objs to work on.
+ if not objs:
+ return False
+
+ # obj_ids: quoted db object identifiers (sometimes schema-qualified)
+ if obj_type == 'function':
+ obj_ids = []
+ for obj in objs:
+ try:
+ f, args = obj.split('(', 1)
+ except Exception:
+ raise Error('Illegal function signature: "%s".' % obj)
+ obj_ids.append('"%s"."%s"(%s' % (schema_qualifier, f, args))
+ elif obj_type in ['table', 'sequence', 'type']:
+ obj_ids = ['"%s"."%s"' % (schema_qualifier, o) for o in objs]
+ else:
+ obj_ids = ['"%s"' % o for o in objs]
+
+ # set_what: SQL-fragment specifying what to set for the target roles:
+ # Either group membership or privileges on objects of a certain type
+ if obj_type == 'group':
+ set_what = ','.join('"%s"' % i for i in obj_ids)
+ elif obj_type == 'default_privs':
+ # We don't want privs to be quoted here
+ set_what = ','.join(privs)
+ else:
+ # function types are already quoted above
+ if obj_type != 'function':
+ obj_ids = [pg_quote_identifier(i, 'table') for i in obj_ids]
+ # Note: obj_type has been checked against a set of string literals
+ # and privs was escaped when it was parsed
+ # Note: Underscores are replaced with spaces to support multi-word obj_type
+ set_what = '%s ON %s %s' % (','.join(privs), obj_type.replace('_', ' '),
+ ','.join(obj_ids))
+
+ # for_whom: SQL-fragment specifying for whom to set the above
+ if roles == 'PUBLIC':
+ for_whom = 'PUBLIC'
+ else:
+ for_whom = []
+ for r in roles:
+ if not role_exists(self.module, self.cursor, r):
+ if fail_on_role:
+ self.module.fail_json(msg="Role '%s' does not exist" % r.strip())
+
+ else:
+ self.module.warn("Role '%s' does not exist, pass it" % r.strip())
+ else:
+ for_whom.append('"%s"' % r)
+
+ if not for_whom:
+ return False
+
+ for_whom = ','.join(for_whom)
+
+ # as_who:
+ as_who = None
+ if target_roles:
+ as_who = ','.join('"%s"' % r for r in target_roles)
+
+ status_before = get_status(objs)
+
+ query = QueryBuilder(state) \
+ .for_objtype(obj_type) \
+ .with_grant_option(grant_option) \
+ .for_whom(for_whom) \
+ .as_who(as_who) \
+ .for_schema(schema_qualifier) \
+ .set_what(set_what) \
+ .for_objs(objs) \
+ .build()
+
+ executed_queries.append(query)
+ self.cursor.execute(query)
+ status_after = get_status(objs)
+
+ def nonesorted(e):
+ # For python 3+ that can fail trying
+ # to compare NoneType elements by sort method.
+ if e is None:
+ return ''
+ return e
+
+ status_before.sort(key=nonesorted)
+ status_after.sort(key=nonesorted)
+ return status_before != status_after
+
+
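+# Assembles GRANT/REVOKE (or ALTER DEFAULT PRIVILEGES) statements through
+# chained setters; see the full call chain in manipulate_privs() above.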
+class QueryBuilder(object):
+ def __init__(self, state):
+ self._grant_option = None
+ self._for_whom = None
+ self._as_who = None
+ self._set_what = None
+ self._obj_type = None
+ self._state = state
+ self._schema = None
+ self._objs = None
+ self.query = []
+
+ def for_objs(self, objs):
+ self._objs = objs
+ return self
+
+ def for_schema(self, schema):
+ self._schema = schema
+ return self
+
+ def with_grant_option(self, option):
+ self._grant_option = option
+ return self
+
+ def for_whom(self, who):
+ self._for_whom = who
+ return self
+
+ def as_who(self, target_roles):
+ self._as_who = target_roles
+ return self
+
+ def set_what(self, what):
+ self._set_what = what
+ return self
+
+ def for_objtype(self, objtype):
+ self._obj_type = objtype
+ return self
+
+ def build(self):
+        if self._state == 'present':
+            self.build_present()
+        else:
+            self.build_absent()
+ return '\n'.join(self.query)
+
+ def add_default_revoke(self):
+ for obj in self._objs:
+ if self._as_who:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} REVOKE ALL ON {2} FROM {3};'.format(self._as_who,
+ self._schema, obj,
+ self._for_whom))
+ else:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} REVOKE ALL ON {1} FROM {2};'.format(self._schema, obj,
+ self._for_whom))
+
+ def add_grant_option(self):
+ if self._grant_option:
+ if self._obj_type == 'group':
+ self.query[-1] += ' WITH ADMIN OPTION;'
+ else:
+ self.query[-1] += ' WITH GRANT OPTION;'
+ else:
+ self.query[-1] += ';'
+ if self._obj_type == 'group':
+ self.query.append('REVOKE ADMIN OPTION FOR {0} FROM {1};'.format(self._set_what, self._for_whom))
+ elif not self._obj_type == 'default_privs':
+ self.query.append('REVOKE GRANT OPTION FOR {0} FROM {1};'.format(self._set_what, self._for_whom))
+
+ def add_default_priv(self):
+ for obj in self._objs:
+ if self._as_who:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} GRANT {2} ON {3} TO {4}'.format(self._as_who,
+ self._schema,
+ self._set_what,
+ obj,
+ self._for_whom))
+ else:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} GRANT {1} ON {2} TO {3}'.format(self._schema,
+ self._set_what,
+ obj,
+ self._for_whom))
+ self.add_grant_option()
+ if self._as_who:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} GRANT USAGE ON TYPES TO {2}'.format(self._as_who,
+ self._schema,
+ self._for_whom))
+ else:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} GRANT USAGE ON TYPES TO {1}'.format(self._schema, self._for_whom))
+ self.add_grant_option()
+
+ def build_present(self):
+ if self._obj_type == 'default_privs':
+ self.add_default_revoke()
+ self.add_default_priv()
+ else:
+ self.query.append('GRANT {0} TO {1}'.format(self._set_what, self._for_whom))
+ self.add_grant_option()
+
+ def build_absent(self):
+ if self._obj_type == 'default_privs':
+ self.query = []
+ for obj in ['TABLES', 'SEQUENCES', 'TYPES']:
+ if self._as_who:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} REVOKE ALL ON {2} FROM {3};'.format(self._as_who,
+ self._schema, obj,
+ self._for_whom))
+ else:
+ self.query.append(
+ 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} REVOKE ALL ON {1} FROM {2};'.format(self._schema, obj,
+ self._for_whom))
+ else:
+ self.query.append('REVOKE {0} FROM {1};'.format(self._set_what, self._for_whom))
+
+
+def main():
+ argument_spec = postgres_common_argument_spec()
+ argument_spec.update(
+ database=dict(required=True, aliases=['db', 'login_db']),
+ state=dict(default='present', choices=['present', 'absent']),
+ privs=dict(required=False, aliases=['priv']),
+ type=dict(default='table',
+ choices=['table',
+ 'sequence',
+ 'function',
+ 'database',
+ 'schema',
+ 'language',
+ 'tablespace',
+ 'group',
+ 'default_privs',
+ 'foreign_data_wrapper',
+ 'foreign_server',
+ 'type', ]),
+ objs=dict(required=False, aliases=['obj']),
+ schema=dict(required=False),
+ roles=dict(required=True, aliases=['role']),
+ session_role=dict(required=False),
+ target_roles=dict(required=False),
+ grant_option=dict(required=False, type='bool',
+ aliases=['admin_option']),
+ host=dict(default='', aliases=['login_host']),
+ unix_socket=dict(default='', aliases=['login_unix_socket']),
+ login=dict(default='postgres', aliases=['login_user']),
+ password=dict(default='', aliases=['login_password'], no_log=True),
+ fail_on_role=dict(type='bool', default=True),
+ )
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True,
+ )
+
+ fail_on_role = module.params['fail_on_role']
+
+ # Create type object as namespace for module params
+ p = type('Params', (), module.params)
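+    # (e.g. module.params['schema'] becomes accessible as p.schema)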
+ # param "schema": default, allowed depends on param "type"
+ if p.type in ['table', 'sequence', 'function', 'type', 'default_privs']:
+ p.schema = p.schema or 'public'
+ elif p.schema:
+ module.fail_json(msg='Argument "schema" is not allowed '
+ 'for type "%s".' % p.type)
+
+ # param "objs": default, required depends on param "type"
+ if p.type == 'database':
+ p.objs = p.objs or p.database
+ elif not p.objs:
+ module.fail_json(msg='Argument "objs" is required '
+ 'for type "%s".' % p.type)
+
+ # param "privs": allowed, required depends on param "type"
+ if p.type == 'group':
+ if p.privs:
+ module.fail_json(msg='Argument "privs" is not allowed '
+ 'for type "group".')
+ elif not p.privs:
+ module.fail_json(msg='Argument "privs" is required '
+ 'for type "%s".' % p.type)
+
+ # Connect to Database
+ if not psycopg2:
+ module.fail_json(msg=missing_required_lib('psycopg2'), exception=PSYCOPG2_IMP_ERR)
+ try:
+ conn = Connection(p, module)
+ except psycopg2.Error as e:
+ module.fail_json(msg='Could not connect to database: %s' % to_native(e), exception=traceback.format_exc())
+ except TypeError as e:
+ if 'sslrootcert' in e.args[0]:
+ module.fail_json(msg='Postgresql server must be at least version 8.4 to support sslrootcert')
+ module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc())
+ except ValueError as e:
+ # We raise this when the psycopg library is too old
+ module.fail_json(msg=to_native(e))
+
+ if p.session_role:
+ try:
+ conn.cursor.execute('SET ROLE "%s"' % p.session_role)
+ except Exception as e:
+ module.fail_json(msg="Could not switch to role %s: %s" % (p.session_role, to_native(e)), exception=traceback.format_exc())
+
+ try:
+ # privs
+ if p.privs:
+ privs = frozenset(pr.upper() for pr in p.privs.split(','))
+ if not privs.issubset(VALID_PRIVS):
+ module.fail_json(msg='Invalid privileges specified: %s' % privs.difference(VALID_PRIVS))
+ else:
+ privs = None
+ # objs:
+ if p.type == 'table' and p.objs == 'ALL_IN_SCHEMA':
+ objs = conn.get_all_tables_in_schema(p.schema)
+ elif p.type == 'sequence' and p.objs == 'ALL_IN_SCHEMA':
+ objs = conn.get_all_sequences_in_schema(p.schema)
+ elif p.type == 'function' and p.objs == 'ALL_IN_SCHEMA':
+ objs = conn.get_all_functions_in_schema(p.schema)
+ elif p.type == 'default_privs':
+ if p.objs == 'ALL_DEFAULT':
+ objs = frozenset(VALID_DEFAULT_OBJS.keys())
+ else:
+ objs = frozenset(obj.upper() for obj in p.objs.split(','))
+ if not objs.issubset(VALID_DEFAULT_OBJS):
+ module.fail_json(
+ msg='Invalid Object set specified: %s' % objs.difference(VALID_DEFAULT_OBJS.keys()))
+ # Again, do we have valid privs specified for object type:
+ valid_objects_for_priv = frozenset(obj for obj in objs if privs.issubset(VALID_DEFAULT_OBJS[obj]))
+ if not valid_objects_for_priv == objs:
+ module.fail_json(
+ msg='Invalid priv specified. Valid object for priv: {0}. Objects: {1}'.format(
+ valid_objects_for_priv, objs))
+ else:
+ objs = p.objs.split(',')
+
+ # function signatures are encoded using ':' to separate args
+ if p.type == 'function':
+ objs = [obj.replace(':', ',') for obj in objs]
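+            # e.g. an obj passed as 'my_func(int:text)' becomes 'my_func(int,text)'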
+
+ # roles
+ if p.roles == 'PUBLIC':
+ roles = 'PUBLIC'
+ else:
+ roles = p.roles.split(',')
+
+ if len(roles) == 1 and not role_exists(module, conn.cursor, roles[0]):
+ module.exit_json(changed=False)
+
+ if fail_on_role:
+ module.fail_json(msg="Role '%s' does not exist" % roles[0].strip())
+
+ else:
+ module.warn("Role '%s' does not exist, nothing to do" % roles[0].strip())
+
+ # check if target_roles is set with type: default_privs
+ if p.target_roles and not p.type == 'default_privs':
+ module.warn('"target_roles" will be ignored '
+ 'Argument "type: default_privs" is required for usage of "target_roles".')
+
+ # target roles
+ if p.target_roles:
+ target_roles = p.target_roles.split(',')
+ else:
+ target_roles = None
+
+ changed = conn.manipulate_privs(
+ obj_type=p.type,
+ privs=privs,
+ objs=objs,
+ roles=roles,
+ target_roles=target_roles,
+ state=p.state,
+ grant_option=p.grant_option,
+ schema_qualifier=p.schema,
+ fail_on_role=fail_on_role,
+ )
+
+ except Error as e:
+ conn.rollback()
+ module.fail_json(msg=e.message, exception=traceback.format_exc())
+
+ except psycopg2.Error as e:
+ conn.rollback()
+ module.fail_json(msg=to_native(e.message))
+
+ if module.check_mode:
+ conn.rollback()
+ else:
+ conn.commit()
+ module.exit_json(changed=changed, queries=executed_queries)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/postgresql_query.py b/test/support/integration/plugins/modules/postgresql_query.py
new file mode 100644
index 00000000..18d63e33
--- /dev/null
+++ b/test/support/integration/plugins/modules/postgresql_query.py
@@ -0,0 +1,364 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017, Felix Archambault
+# Copyright: (c) 2019, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+ANSIBLE_METADATA = {
+ 'metadata_version': '1.1',
+ 'supported_by': 'community',
+ 'status': ['preview']
+}
+
+DOCUMENTATION = r'''
+---
+module: postgresql_query
+short_description: Run PostgreSQL queries
+description:
+- Runs arbitrary PostgreSQL queries.
+- Can run queries from SQL script files.
+- Does not run against backup files. Use M(postgresql_db) with I(state=restore)
+ to run queries on files made by pg_dump/pg_dumpall utilities.
+version_added: '2.8'
+options:
+ query:
+ description:
+ - SQL query to run. Variables can be escaped with psycopg2 syntax
+ U(http://initd.org/psycopg/docs/usage.html).
+ type: str
+ positional_args:
+ description:
+ - List of values to be passed as positional arguments to the query.
+ When the value is a list, it will be converted to PostgreSQL array.
+ - Mutually exclusive with I(named_args).
+ type: list
+ elements: raw
+ named_args:
+ description:
+ - Dictionary of key-value arguments to pass to the query.
+ When the value is a list, it will be converted to PostgreSQL array.
+ - Mutually exclusive with I(positional_args).
+ type: dict
+ path_to_script:
+ description:
+ - Path to SQL script on the remote host.
+ - Returns result of the last query in the script.
+ - Mutually exclusive with I(query).
+ type: path
+ session_role:
+ description:
+ - Switch to session_role after connecting. The specified session_role must
+ be a role that the current login_user is a member of.
+ - Permissions checking for SQL commands is carried out as though
+ the session_role were the one that had logged in originally.
+ type: str
+ db:
+ description:
+ - Name of database to connect to and run queries against.
+ type: str
+ aliases:
+ - login_db
+ autocommit:
+ description:
+ - Execute in autocommit mode when the query can't be run inside a transaction block
+ (e.g., VACUUM).
+ - Mutually exclusive with I(check_mode).
+ type: bool
+ default: no
+ version_added: '2.9'
+ encoding:
+ description:
+ - Set the client encoding for the current session (e.g. C(UTF-8)).
+ - The default is the encoding defined by the database.
+ type: str
+ version_added: '2.10'
+seealso:
+- module: postgresql_db
+author:
+- Felix Archambault (@archf)
+- Andrew Klychkov (@Andersson007)
+- Will Rouesnel (@wrouesnel)
+extends_documentation_fragment: postgres
+'''
+
+EXAMPLES = r'''
+- name: Simple select query to acme db
+ postgresql_query:
+ db: acme
+ query: SELECT version()
+
+- name: Select query to db acme with positional arguments and non-default credentials
+ postgresql_query:
+ db: acme
+ login_user: django
+ login_password: mysecretpass
+ query: SELECT * FROM acme WHERE id = %s AND story = %s
+ positional_args:
+ - 1
+ - test
+
+- name: Select query to test_db with named_args
+ postgresql_query:
+ db: test_db
+ query: SELECT * FROM test WHERE id = %(id_val)s AND story = %(story_val)s
+ named_args:
+ id_val: 1
+ story_val: test
+
+- name: Insert query to test_table in db test_db
+ postgresql_query:
+ db: test_db
+ query: INSERT INTO test_table (id, story) VALUES (2, 'my_long_story')
+
+- name: Run queries from SQL script using UTF-8 client encoding for session
+ postgresql_query:
+ db: test_db
+ path_to_script: /var/lib/pgsql/test.sql
+ positional_args:
+ - 1
+ encoding: UTF-8
+
+- name: Example of using autocommit parameter
+ postgresql_query:
+ db: test_db
+ query: VACUUM
+ autocommit: yes
+
+- name: >
+ Insert data to the column of array type using positional_args.
+ Note that we use quotes here, the same as for passing JSON, etc.
+ postgresql_query:
+ query: INSERT INTO test_table (array_column) VALUES (%s)
+ positional_args:
+ - '{1,2,3}'
+
+# Pass list and string vars as positional_args
+- name: Set vars
+ set_fact:
+ my_list:
+ - 1
+ - 2
+ - 3
+ my_arr: '{1, 2, 3}'
+
+- name: Select from test table by passing positional_args as arrays
+ postgresql_query:
+ query: SELECT * FROM test_array_table WHERE arr_col1 = %s AND arr_col2 = %s
+ positional_args:
+ - '{{ my_list }}'
+ - '{{ my_arr|string }}'
+'''
+
+RETURN = r'''
+query:
+    description: The query that the module tried to execute.
+    returned: always
+    type: str
+    sample: 'SELECT * FROM bar'
+statusmessage:
+ description: Attribute containing the message returned by the command.
+ returned: always
+ type: str
+ sample: 'INSERT 0 1'
+query_result:
+ description:
+ - List of dictionaries in column:value form representing returned rows.
+ returned: changed
+ type: list
+ sample: [{"Column": "Value1"},{"Column": "Value2"}]
+rowcount:
+ description: Number of affected rows.
+ returned: changed
+ type: int
+ sample: 5
+'''
+
+try:
+ from psycopg2 import ProgrammingError as Psycopg2ProgrammingError
+ from psycopg2.extras import DictCursor
+except ImportError:
+    # This import is needed only to handle 'no results to fetch' errors in main();
+    # psycopg2 availability itself is checked by connect_to_db() in
+    # ansible.module_utils.postgres
+ pass
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.postgres import (
+ connect_to_db,
+ get_conn_params,
+ postgres_common_argument_spec,
+)
+from ansible.module_utils._text import to_native
+from ansible.module_utils.six import iteritems
+
+
+# ===========================================
+# Module execution.
+#
+
+def list_to_pg_array(elem):
+ """Convert the passed list to PostgreSQL array
+ represented as a string.
+
+ Args:
+ elem (list): List that needs to be converted.
+
+ Returns:
+ elem (str): String representation of PostgreSQL array.
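+
+    Example:
+        list_to_pg_array([1, 2, 3]) returns '{1, 2, 3}'.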
+ """
+ elem = str(elem).strip('[]')
+ elem = '{' + elem + '}'
+ return elem
+
+
+def convert_elements_to_pg_arrays(obj):
+ """Convert list elements of the passed object
+ to PostgreSQL arrays represented as strings.
+
+ Args:
+ obj (dict or list): Object whose elements need to be converted.
+
+ Returns:
+ obj (dict or list): Object with converted elements.
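+
+    Example:
+        convert_elements_to_pg_arrays({'ids': [1, 2]}) returns {'ids': '{1, 2}'}.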
+ """
+ if isinstance(obj, dict):
+ for (key, elem) in iteritems(obj):
+ if isinstance(elem, list):
+ obj[key] = list_to_pg_array(elem)
+
+ elif isinstance(obj, list):
+ for i, elem in enumerate(obj):
+ if isinstance(elem, list):
+ obj[i] = list_to_pg_array(elem)
+
+ return obj
+
+
+def main():
+ argument_spec = postgres_common_argument_spec()
+ argument_spec.update(
+ query=dict(type='str'),
+ db=dict(type='str', aliases=['login_db']),
+ positional_args=dict(type='list', elements='raw'),
+ named_args=dict(type='dict'),
+ session_role=dict(type='str'),
+ path_to_script=dict(type='path'),
+ autocommit=dict(type='bool', default=False),
+ encoding=dict(type='str'),
+ )
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ mutually_exclusive=(('positional_args', 'named_args'),),
+ supports_check_mode=True,
+ )
+
+ query = module.params["query"]
+ positional_args = module.params["positional_args"]
+ named_args = module.params["named_args"]
+ path_to_script = module.params["path_to_script"]
+ autocommit = module.params["autocommit"]
+ encoding = module.params["encoding"]
+
+ if autocommit and module.check_mode:
+ module.fail_json(msg="Using autocommit is mutually exclusive with check_mode")
+
+ if path_to_script and query:
+ module.fail_json(msg="path_to_script is mutually exclusive with query")
+
+ if positional_args:
+ positional_args = convert_elements_to_pg_arrays(positional_args)
+
+ elif named_args:
+ named_args = convert_elements_to_pg_arrays(named_args)
+
+ if path_to_script:
+ try:
+ with open(path_to_script, 'rb') as f:
+ query = to_native(f.read())
+ except Exception as e:
+ module.fail_json(msg="Cannot read file '%s' : %s" % (path_to_script, to_native(e)))
+
+ conn_params = get_conn_params(module, module.params)
+ db_connection = connect_to_db(module, conn_params, autocommit=autocommit)
+ if encoding is not None:
+ db_connection.set_client_encoding(encoding)
+ cursor = db_connection.cursor(cursor_factory=DictCursor)
+
+ # Prepare args:
+ if module.params.get("positional_args"):
+ arguments = module.params["positional_args"]
+ elif module.params.get("named_args"):
+ arguments = module.params["named_args"]
+ else:
+ arguments = None
+
+ # Set defaults:
+ changed = False
+
+ # Execute query:
+ try:
+ cursor.execute(query, arguments)
+ except Exception as e:
+ if not autocommit:
+ db_connection.rollback()
+
+ cursor.close()
+ db_connection.close()
+ module.fail_json(msg="Cannot execute SQL '%s' %s: %s" % (query, arguments, to_native(e)))
+
+ statusmessage = cursor.statusmessage
+ rowcount = cursor.rowcount
+
+ try:
+ query_result = [dict(row) for row in cursor.fetchall()]
+    except Psycopg2ProgrammingError:
+        # Raised as 'no results to fetch' for statements that return no rows
+        # (e.g. INSERT); default query_result so that it is always defined.
+        query_result = {}
+
+ except Exception as e:
+ module.fail_json(msg="Cannot fetch rows from cursor: %s" % to_native(e))
+
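+    # Infer 'changed' from the command tag, e.g. 'INSERT 0 1' (three tokens,
+    # the last one being the affected-row count) or 'UPDATE 5' (two tokens).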
+ if 'SELECT' not in statusmessage:
+ if 'UPDATE' in statusmessage or 'INSERT' in statusmessage or 'DELETE' in statusmessage:
+ s = statusmessage.split()
+ if len(s) == 3:
+ if statusmessage.split()[2] != '0':
+ changed = True
+
+ elif len(s) == 2:
+ if statusmessage.split()[1] != '0':
+ changed = True
+
+ else:
+ changed = True
+
+ else:
+ changed = True
+
+ if module.check_mode:
+ db_connection.rollback()
+ else:
+ if not autocommit:
+ db_connection.commit()
+
+ kw = dict(
+ changed=changed,
+ query=cursor.query,
+ statusmessage=statusmessage,
+ query_result=query_result,
+ rowcount=rowcount if rowcount >= 0 else 0,
+ )
+
+ cursor.close()
+ db_connection.close()
+
+ module.exit_json(**kw)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/postgresql_set.py b/test/support/integration/plugins/modules/postgresql_set.py
new file mode 100644
index 00000000..cfbdae64
--- /dev/null
+++ b/test/support/integration/plugins/modules/postgresql_set.py
@@ -0,0 +1,434 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2018, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {
+ 'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'
+}
+
+DOCUMENTATION = r'''
+---
+module: postgresql_set
+short_description: Change a PostgreSQL server configuration parameter
+description:
+  - Allows changing a PostgreSQL server configuration parameter.
+  - The module uses the ALTER SYSTEM command and applies changes by reloading the server configuration.
+  - ALTER SYSTEM is used for changing server configuration parameters across the entire database cluster.
+  - It can be more convenient and safe than the traditional method of manually editing the postgresql.conf file.
+  - ALTER SYSTEM writes the given parameter setting to the $PGDATA/postgresql.auto.conf file,
+    which is read in addition to postgresql.conf.
+  - The module allows resetting a parameter to boot_val (the cluster initial value) with I(reset=yes), or removing the
+    parameter string from postgresql.auto.conf and reloading the configuration with I(value=default)
+    (for settings with postmaster context a restart is required).
+  - After the change you can see the previous and the new parameter values and other
+    information in the Ansible output by using the returned values and the M(debug) module.
+version_added: '2.8'
+options:
+ name:
+ description:
+ - Name of PostgreSQL server parameter.
+ type: str
+ required: true
+ value:
+ description:
+    - Parameter value to set.
+    - To remove a parameter string from postgresql.auto.conf and
+      reload the server configuration you must pass I(value=default).
+      With I(value=default) the task always returns C(changed=true).
+ type: str
+ reset:
+ description:
+ - Restore parameter to initial state (boot_val). Mutually exclusive with I(value).
+ type: bool
+ default: false
+ session_role:
+ description:
+ - Switch to session_role after connecting. The specified session_role must
+ be a role that the current login_user is a member of.
+ - Permissions checking for SQL commands is carried out as though
+ the session_role were the one that had logged in originally.
+ type: str
+ db:
+ description:
+ - Name of database to connect.
+ type: str
+ aliases:
+ - login_db
+notes:
+- PostgreSQL 9.4 and later is supported.
+- Be aware that changing a setting with 'postmaster' context can return C(changed=true)
+  when nothing actually changes, because the same value may be presented in
+  several different forms, for example, 1024MB, 1GB, etc., while in the pg_settings
+  system view it can be stored as 131072 (a number of 8kB pages).
+  The final check of the parameter value cannot compare these because the server has
+  not been restarted and the value in pg_settings is not updated yet.
+- Some parameters require a restart of the PostgreSQL server.
+  See the official documentation U(https://www.postgresql.org/docs/current/view-pg-settings.html).
+seealso:
+- module: postgresql_info
+- name: PostgreSQL server configuration
+ description: General information about PostgreSQL server configuration.
+ link: https://www.postgresql.org/docs/current/runtime-config.html
+- name: PostgreSQL view pg_settings reference
+ description: Complete reference of the pg_settings view documentation.
+ link: https://www.postgresql.org/docs/current/view-pg-settings.html
+- name: PostgreSQL ALTER SYSTEM command reference
+ description: Complete reference of the ALTER SYSTEM command documentation.
+ link: https://www.postgresql.org/docs/current/sql-altersystem.html
+author:
+- Andrew Klychkov (@Andersson007)
+extends_documentation_fragment: postgres
+'''
+
+EXAMPLES = r'''
+- name: Restore wal_keep_segments parameter to initial state
+ postgresql_set:
+ name: wal_keep_segments
+ reset: yes
+
+# Set the work_mem parameter to 32MB, show what has been changed, and whether a restart is required
+# (output example: "msg": "work_mem 4MB >> 64MB restart_req: False")
+- name: Set work mem parameter
+ postgresql_set:
+ name: work_mem
+ value: 32mb
+ register: set
+
+- debug:
+ msg: "{{ set.name }} {{ set.prev_val_pretty }} >> {{ set.value_pretty }} restart_req: {{ set.restart_required }}"
+ when: set.changed
+# Note that some parameters require a restart of the PostgreSQL server.
+# In this situation you see the same parameter in prev_val_pretty and value_pretty, but changed=True
+# (if you passed a value that was different from the current server setting).
+
+- name: Set log_min_duration_statement parameter to 1 second
+ postgresql_set:
+ name: log_min_duration_statement
+ value: 1s
+
+- name: Set wal_log_hints parameter to default value (remove parameter from postgresql.auto.conf)
+ postgresql_set:
+ name: wal_log_hints
+ value: default
+'''
+
+RETURN = r'''
+name:
+ description: Name of PostgreSQL server parameter.
+ returned: always
+ type: str
+ sample: 'shared_buffers'
+restart_required:
+ description: Information about parameter current state.
+ returned: always
+ type: bool
+ sample: true
+prev_val_pretty:
+ description: Information about previous state of the parameter.
+ returned: always
+ type: str
+ sample: '4MB'
+value_pretty:
+ description: Information about current state of the parameter.
+ returned: always
+ type: str
+ sample: '64MB'
+value:
+ description:
+    - Dictionary that contains the current parameter value (at the time the playbook finishes).
+    - Note that some parameters require a restart of the PostgreSQL server for the change to take effect.
+    - The current value is returned in check mode as well.
+ returned: always
+ type: dict
+ sample: { "value": 67108864, "unit": "b" }
+context:
+ description:
+ - PostgreSQL setting context.
+ returned: always
+ type: str
+ sample: user
+'''
+
+try:
+ from psycopg2.extras import DictCursor
+except Exception:
+ # psycopg2 is checked by connect_to_db()
+ # from ansible.module_utils.postgres
+ pass
+
+from copy import deepcopy
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.postgres import (
+ connect_to_db,
+ get_conn_params,
+ postgres_common_argument_spec,
+)
+from ansible.module_utils._text import to_native
+
+PG_REQ_VER = 90400
+
+# To allow setting values like 1mb instead of 1MB, etc.:
+POSSIBLE_SIZE_UNITS = ("mb", "gb", "tb")
+
+# ===========================================
+# PostgreSQL module specific support methods.
+#
+
+
+def param_get(cursor, module, name):
+ query = ("SELECT name, setting, unit, context, boot_val "
+ "FROM pg_settings WHERE name = %(name)s")
+ try:
+ cursor.execute(query, {'name': name})
+ info = cursor.fetchall()
+ cursor.execute("SHOW %s" % name)
+ val = cursor.fetchone()
+
+ except Exception as e:
+ module.fail_json(msg="Unable to get %s value due to : %s" % (name, to_native(e)))
+
+ raw_val = info[0][1]
+ unit = info[0][2]
+ context = info[0][3]
+ boot_val = info[0][4]
+
+ if val[0] == 'True':
+ val[0] = 'on'
+ elif val[0] == 'False':
+ val[0] = 'off'
+
+ if unit == 'kB':
+ if int(raw_val) > 0:
+ raw_val = int(raw_val) * 1024
+ if int(boot_val) > 0:
+ boot_val = int(boot_val) * 1024
+
+ unit = 'b'
+
+ elif unit == 'MB':
+ if int(raw_val) > 0:
+ raw_val = int(raw_val) * 1024 * 1024
+ if int(boot_val) > 0:
+ boot_val = int(boot_val) * 1024 * 1024
+
+ unit = 'b'
+
+ return (val[0], raw_val, unit, boot_val, context)
+
+
+def pretty_to_bytes(pretty_val):
+ # The function returns a value in bytes
+ # if the value contains 'B', 'kB', 'MB', 'GB', 'TB'.
+ # Otherwise it returns the passed argument.
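+    # For example, '4MB' is converted to 4194304, while a value without
+    # a recognized unit suffix (e.g. 'on') is returned unchanged.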
+
+ val_in_bytes = None
+
+ if 'kB' in pretty_val:
+ num_part = int(''.join(d for d in pretty_val if d.isdigit()))
+ val_in_bytes = num_part * 1024
+
+ elif 'MB' in pretty_val.upper():
+ num_part = int(''.join(d for d in pretty_val if d.isdigit()))
+ val_in_bytes = num_part * 1024 * 1024
+
+ elif 'GB' in pretty_val.upper():
+ num_part = int(''.join(d for d in pretty_val if d.isdigit()))
+ val_in_bytes = num_part * 1024 * 1024 * 1024
+
+ elif 'TB' in pretty_val.upper():
+ num_part = int(''.join(d for d in pretty_val if d.isdigit()))
+ val_in_bytes = num_part * 1024 * 1024 * 1024 * 1024
+
+ elif 'B' in pretty_val.upper():
+ num_part = int(''.join(d for d in pretty_val if d.isdigit()))
+ val_in_bytes = num_part
+
+ else:
+ return pretty_val
+
+ return val_in_bytes
+
+
+def param_set(cursor, module, name, value, context):
+ try:
+ if str(value).lower() == 'default':
+ query = "ALTER SYSTEM SET %s = DEFAULT" % name
+ else:
+ query = "ALTER SYSTEM SET %s = '%s'" % (name, value)
+ cursor.execute(query)
+
+ if context != 'postmaster':
+ cursor.execute("SELECT pg_reload_conf()")
+
+ except Exception as e:
+ module.fail_json(msg="Unable to get %s value due to : %s" % (name, to_native(e)))
+
+ return True
+
+
+# ===========================================
+# Module execution.
+#
+
+
+def main():
+ argument_spec = postgres_common_argument_spec()
+ argument_spec.update(
+ name=dict(type='str', required=True),
+ db=dict(type='str', aliases=['login_db']),
+ value=dict(type='str'),
+ reset=dict(type='bool'),
+ session_role=dict(type='str'),
+ )
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True,
+ )
+
+ name = module.params["name"]
+ value = module.params["value"]
+ reset = module.params["reset"]
+
+    # Allow passing values like 1mb instead of 1MB, etc. (e.g. '1mb' is normalized to '1MB'):
+ if value:
+ for unit in POSSIBLE_SIZE_UNITS:
+ if value[:-2].isdigit() and unit in value[-2:]:
+ value = value.upper()
+
+ if value and reset:
+ module.fail_json(msg="%s: value and reset params are mutually exclusive" % name)
+
+ if not value and not reset:
+ module.fail_json(msg="%s: at least one of value or reset param must be specified" % name)
+
+ conn_params = get_conn_params(module, module.params, warn_db_default=False)
+ db_connection = connect_to_db(module, conn_params, autocommit=True)
+ cursor = db_connection.cursor(cursor_factory=DictCursor)
+
+ kw = {}
+ # Check server version (needs 9.4 or later):
+ ver = db_connection.server_version
+ if ver < PG_REQ_VER:
+ module.warn("PostgreSQL is %s version but %s or later is required" % (ver, PG_REQ_VER))
+ kw = dict(
+ changed=False,
+ restart_required=False,
+ value_pretty="",
+ prev_val_pretty="",
+ value={"value": "", "unit": ""},
+ )
+ kw['name'] = name
+ db_connection.close()
+ module.exit_json(**kw)
+
+ # Set default returned values:
+ restart_required = False
+ changed = False
+ kw['name'] = name
+ kw['restart_required'] = False
+
+ # Get info about param state:
+ res = param_get(cursor, module, name)
+ current_value = res[0]
+ raw_val = res[1]
+ unit = res[2]
+ boot_val = res[3]
+ context = res[4]
+
+ if value == 'True':
+ value = 'on'
+ elif value == 'False':
+ value = 'off'
+
+ kw['prev_val_pretty'] = current_value
+ kw['value_pretty'] = deepcopy(kw['prev_val_pretty'])
+ kw['context'] = context
+
+ # Do job
+ if context == "internal":
+ module.fail_json(msg="%s: cannot be changed (internal context). See "
+ "https://www.postgresql.org/docs/current/runtime-config-preset.html" % name)
+
+ if context == "postmaster":
+ restart_required = True
+
+ # If check_mode, just compare and exit:
+ if module.check_mode:
+ if pretty_to_bytes(value) == pretty_to_bytes(current_value):
+ kw['changed'] = False
+
+ else:
+ kw['value_pretty'] = value
+ kw['changed'] = True
+
+ # Anyway returns current raw value in the check_mode:
+ kw['value'] = dict(
+ value=raw_val,
+ unit=unit,
+ )
+ kw['restart_required'] = restart_required
+ module.exit_json(**kw)
+
+ # Set param:
+ if value and value != current_value:
+ changed = param_set(cursor, module, name, value, context)
+
+ kw['value_pretty'] = value
+
+ # Reset param:
+ elif reset:
+ if raw_val == boot_val:
+ # nothing to change, exit:
+ kw['value'] = dict(
+ value=raw_val,
+ unit=unit,
+ )
+ module.exit_json(**kw)
+
+ changed = param_set(cursor, module, name, boot_val, context)
+
+ if restart_required:
+ module.warn("Restart of PostgreSQL is required for setting %s" % name)
+
+ cursor.close()
+ db_connection.close()
+
+ # Reconnect and recheck current value:
+ if context in ('sighup', 'superuser-backend', 'backend', 'superuser', 'user'):
+ db_connection = connect_to_db(module, conn_params, autocommit=True)
+ cursor = db_connection.cursor(cursor_factory=DictCursor)
+
+ res = param_get(cursor, module, name)
+ # f_ means 'final'
+ f_value = res[0]
+ f_raw_val = res[1]
+
+ if raw_val == f_raw_val:
+ changed = False
+
+ else:
+ changed = True
+
+ kw['value_pretty'] = f_value
+ kw['value'] = dict(
+ value=f_raw_val,
+ unit=unit,
+ )
+
+ cursor.close()
+ db_connection.close()
+
+ kw['changed'] = changed
+ kw['restart_required'] = restart_required
+ module.exit_json(**kw)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/postgresql_table.py b/test/support/integration/plugins/modules/postgresql_table.py
new file mode 100644
index 00000000..3bef03b0
--- /dev/null
+++ b/test/support/integration/plugins/modules/postgresql_table.py
@@ -0,0 +1,601 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2019, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {
+ 'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'
+}
+
+DOCUMENTATION = r'''
+---
+module: postgresql_table
+short_description: Create, drop, or modify a PostgreSQL table
+description:
+- Allows you to create, drop, rename, or truncate a table, or change some table attributes.
+version_added: '2.8'
+options:
+ table:
+ description:
+ - Table name.
+ required: true
+ aliases:
+ - name
+ type: str
+ state:
+ description:
+ - The table state. I(state=absent) is mutually exclusive with I(tablespace), I(owner), I(unlogged),
+      I(like), I(including), I(columns), I(truncate), I(storage_params), and I(rename).
+ type: str
+ default: present
+ choices: [ absent, present ]
+ tablespace:
+ description:
+ - Set a tablespace for the table.
+ required: false
+ type: str
+ owner:
+ description:
+ - Set a table owner.
+ type: str
+ unlogged:
+ description:
+ - Create an unlogged table.
+ type: bool
+ default: no
+ like:
+ description:
+ - Create a table like another table (with similar DDL).
+ Mutually exclusive with I(columns), I(rename), and I(truncate).
+ type: str
+ including:
+ description:
+    - Keywords that are used with the like parameter; may be DEFAULTS, CONSTRAINTS, INDEXES, STORAGE, COMMENTS or ALL.
+ Needs I(like) specified. Mutually exclusive with I(columns), I(rename), and I(truncate).
+ type: str
+ columns:
+ description:
+ - Columns that are needed.
+ type: list
+ elements: str
+ rename:
+ description:
+ - New table name. Mutually exclusive with I(tablespace), I(owner),
+ I(unlogged), I(like), I(including), I(columns), I(truncate), and I(storage_params).
+ type: str
+ truncate:
+ description:
+ - Truncate a table. Mutually exclusive with I(tablespace), I(owner), I(unlogged),
+ I(like), I(including), I(columns), I(rename), and I(storage_params).
+ type: bool
+ default: no
+ storage_params:
+ description:
+    - Storage parameters like fillfactor, autovacuum_vacuum_threshold, etc.
+ Mutually exclusive with I(rename) and I(truncate).
+ type: list
+ elements: str
+ db:
+ description:
+    - Name of the database to connect to and where the table will be created.
+ type: str
+ aliases:
+ - login_db
+ session_role:
+ description:
+ - Switch to session_role after connecting.
+ The specified session_role must be a role that the current login_user is a member of.
+ - Permissions checking for SQL commands is carried out as though
+ the session_role were the one that had logged in originally.
+ type: str
+ cascade:
+ description:
+ - Automatically drop objects that depend on the table (such as views).
+ Used with I(state=absent) only.
+ type: bool
+ default: no
+ version_added: '2.9'
+notes:
+- If you do not pass the db parameter, tables will be created in the database
+  named postgres.
+- PostgreSQL allows creating columnless tables, so the columns param is optional.
+- Unlogged tables are available from PostgreSQL server version 9.1.
+seealso:
+- module: postgresql_sequence
+- module: postgresql_idx
+- module: postgresql_info
+- module: postgresql_tablespace
+- module: postgresql_owner
+- module: postgresql_privs
+- module: postgresql_copy
+- name: CREATE TABLE reference
+ description: Complete reference of the CREATE TABLE command documentation.
+ link: https://www.postgresql.org/docs/current/sql-createtable.html
+- name: ALTER TABLE reference
+ description: Complete reference of the ALTER TABLE command documentation.
+ link: https://www.postgresql.org/docs/current/sql-altertable.html
+- name: DROP TABLE reference
+ description: Complete reference of the DROP TABLE command documentation.
+ link: https://www.postgresql.org/docs/current/sql-droptable.html
+- name: PostgreSQL data types
+ description: Complete reference of the PostgreSQL data types documentation.
+ link: https://www.postgresql.org/docs/current/datatype.html
+author:
+- Andrei Klychkov (@Andersson007)
+extends_documentation_fragment: postgres
+'''
+
+EXAMPLES = r'''
+- name: Create tbl2 in the acme database with the DDL like tbl1 with testuser as an owner
+ postgresql_table:
+ db: acme
+ name: tbl2
+ like: tbl1
+ owner: testuser
+
+- name: Create tbl2 in the acme database and tablespace ssd with the DDL like tbl1 including comments and indexes
+ postgresql_table:
+ db: acme
+ table: tbl2
+ like: tbl1
+ including: comments, indexes
+ tablespace: ssd
+
+- name: Create test_table with several columns in ssd tablespace with fillfactor=10 and autovacuum_analyze_threshold=1
+ postgresql_table:
+ name: test_table
+ columns:
+ - id bigserial primary key
+ - num bigint
+ - stories text
+ tablespace: ssd
+ storage_params:
+ - fillfactor=10
+ - autovacuum_analyze_threshold=1
+
+- name: Create an unlogged table in schema acme
+ postgresql_table:
+ name: acme.useless_data
+ columns: waste_id int
+ unlogged: true
+
+- name: Rename table foo to bar
+ postgresql_table:
+ table: foo
+ rename: bar
+
+- name: Rename table foo from schema acme to bar
+ postgresql_table:
+ name: acme.foo
+ rename: bar
+
+- name: Set owner to someuser
+ postgresql_table:
+ name: foo
+ owner: someuser
+
+- name: Change tablespace of foo table to new_tablespace and set owner to new_user
+ postgresql_table:
+ name: foo
+ tablespace: new_tablespace
+ owner: new_user
+
+- name: Truncate table foo
+ postgresql_table:
+ name: foo
+ truncate: yes
+
+- name: Drop table foo from schema acme
+ postgresql_table:
+ name: acme.foo
+ state: absent
+
+- name: Drop table bar cascade
+ postgresql_table:
+ name: bar
+ state: absent
+ cascade: yes
+'''
+
+RETURN = r'''
+table:
+ description: Name of a table.
+ returned: always
+ type: str
+ sample: 'foo'
+state:
+ description: Table state.
+ returned: always
+ type: str
+ sample: 'present'
+owner:
+ description: Table owner.
+ returned: always
+ type: str
+ sample: 'postgres'
+tablespace:
+ description: Tablespace.
+ returned: always
+ type: str
+ sample: 'ssd_tablespace'
+queries:
+  description: List of executed queries.
+  returned: always
+  type: list
+  sample: [ 'CREATE TABLE "test_table" (id bigint)' ]
+storage_params:
+ description: Storage parameters.
+ returned: always
+ type: list
+ sample: [ "fillfactor=100", "autovacuum_analyze_threshold=1" ]
+'''
+
+try:
+ from psycopg2.extras import DictCursor
+except ImportError:
+ # psycopg2 is checked by connect_to_db()
+ # from ansible.module_utils.postgres
+ pass
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.database import pg_quote_identifier
+from ansible.module_utils.postgres import (
+ connect_to_db,
+ exec_sql,
+ get_conn_params,
+ postgres_common_argument_spec,
+)
+
+
+# ===========================================
+# PostgreSQL module specific support methods.
+#
+
+class Table(object):
+ def __init__(self, name, module, cursor):
+ self.name = name
+ self.module = module
+ self.cursor = cursor
+ self.info = {
+ 'owner': '',
+ 'tblspace': '',
+ 'storage_params': [],
+ }
+ self.exists = False
+ self.__exists_in_db()
+ self.executed_queries = []
+
+ def get_info(self):
+ """Getter to refresh and get table info"""
+ self.__exists_in_db()
+
+ def __exists_in_db(self):
+ """Check table exists and refresh info"""
+ if "." in self.name:
+ schema = self.name.split('.')[-2]
+ tblname = self.name.split('.')[-1]
+ else:
+ schema = 'public'
+ tblname = self.name
+
+ query = ("SELECT t.tableowner, t.tablespace, c.reloptions "
+ "FROM pg_tables AS t "
+ "INNER JOIN pg_class AS c ON c.relname = t.tablename "
+ "INNER JOIN pg_namespace AS n ON c.relnamespace = n.oid "
+ "WHERE t.tablename = %(tblname)s "
+ "AND n.nspname = %(schema)s")
+ res = exec_sql(self, query, query_params={'tblname': tblname, 'schema': schema},
+ add_to_executed=False)
+ if res:
+ self.exists = True
+ self.info = dict(
+ owner=res[0][0],
+ tblspace=res[0][1] if res[0][1] else '',
+ storage_params=res[0][2] if res[0][2] else [],
+ )
+
+ return True
+ else:
+ self.exists = False
+ return False
+
+ def create(self, columns='', params='', tblspace='',
+ unlogged=False, owner=''):
+ """
+ Create table.
+ If table exists, check passed args (params, tblspace, owner) and,
+ if they're different from current, change them.
+ Arguments:
+ params - storage params (passed by "WITH (...)" in SQL),
+ comma separated.
+ tblspace - tablespace.
+ owner - table owner.
+ unlogged - create unlogged table.
+ columns - column string (comma separated).
+ """
+ name = pg_quote_identifier(self.name, 'table')
+
+ changed = False
+
+ if self.exists:
+            if tblspace == 'pg_default' and not self.info['tblspace']:
+                pass  # Because an empty tblspace and pg_default have the same meaning
+ elif tblspace and self.info['tblspace'] != tblspace:
+ self.set_tblspace(tblspace)
+ changed = True
+
+ if owner and self.info['owner'] != owner:
+ self.set_owner(owner)
+ changed = True
+
+ if params:
+ param_list = [p.strip(' ') for p in params.split(',')]
+
+ new_param = False
+ for p in param_list:
+ if p not in self.info['storage_params']:
+ new_param = True
+
+ if new_param:
+ self.set_stor_params(params)
+ changed = True
+
+ if changed:
+ return True
+ return False
+
+ query = "CREATE"
+ if unlogged:
+ query += " UNLOGGED TABLE %s" % name
+ else:
+ query += " TABLE %s" % name
+
+ if columns:
+ query += " (%s)" % columns
+ else:
+ query += " ()"
+
+ if params:
+ query += " WITH (%s)" % params
+
+ if tblspace:
+ query += " TABLESPACE %s" % pg_quote_identifier(tblspace, 'database')
+
+ if exec_sql(self, query, ddl=True):
+ changed = True
+
+ if owner:
+ changed = self.set_owner(owner)
+
+ return changed
+
+ def create_like(self, src_table, including='', tblspace='',
+ unlogged=False, params='', owner=''):
+ """
+ Create table like another table (with similar DDL).
+ Arguments:
+ src_table - source table.
+ including - corresponds to optional INCLUDING expression
+ in CREATE TABLE ... LIKE statement.
+ params - storage params (passed by "WITH (...)" in SQL),
+ comma separated.
+ tblspace - tablespace.
+ owner - table owner.
+ unlogged - create unlogged table.
+ """
+ changed = False
+
+ name = pg_quote_identifier(self.name, 'table')
+
+ query = "CREATE"
+ if unlogged:
+ query += " UNLOGGED TABLE %s" % name
+ else:
+ query += " TABLE %s" % name
+
+ query += " (LIKE %s" % pg_quote_identifier(src_table, 'table')
+
+ if including:
+ including = including.split(',')
+ for i in including:
+ query += " INCLUDING %s" % i
+
+ query += ')'
+
+ if params:
+ query += " WITH (%s)" % params
+
+ if tblspace:
+ query += " TABLESPACE %s" % pg_quote_identifier(tblspace, 'database')
+
+ if exec_sql(self, query, ddl=True):
+ changed = True
+
+ if owner:
+ changed = self.set_owner(owner)
+
+ return changed
+
+ def truncate(self):
+ query = "TRUNCATE TABLE %s" % pg_quote_identifier(self.name, 'table')
+ return exec_sql(self, query, ddl=True)
+
+ def rename(self, newname):
+ query = "ALTER TABLE %s RENAME TO %s" % (pg_quote_identifier(self.name, 'table'),
+ pg_quote_identifier(newname, 'table'))
+ return exec_sql(self, query, ddl=True)
+
+ def set_owner(self, username):
+ query = "ALTER TABLE %s OWNER TO %s" % (pg_quote_identifier(self.name, 'table'),
+ pg_quote_identifier(username, 'role'))
+ return exec_sql(self, query, ddl=True)
+
+ def drop(self, cascade=False):
+ if not self.exists:
+ return False
+
+ query = "DROP TABLE %s" % pg_quote_identifier(self.name, 'table')
+ if cascade:
+ query += " CASCADE"
+ return exec_sql(self, query, ddl=True)
+
+ def set_tblspace(self, tblspace):
+ query = "ALTER TABLE %s SET TABLESPACE %s" % (pg_quote_identifier(self.name, 'table'),
+ pg_quote_identifier(tblspace, 'database'))
+ return exec_sql(self, query, ddl=True)
+
+ def set_stor_params(self, params):
+ query = "ALTER TABLE %s SET (%s)" % (pg_quote_identifier(self.name, 'table'), params)
+ return exec_sql(self, query, ddl=True)
+
+
+# ===========================================
+# Module execution.
+#
+
+
+def main():
+ argument_spec = postgres_common_argument_spec()
+ argument_spec.update(
+ table=dict(type='str', required=True, aliases=['name']),
+ state=dict(type='str', default="present", choices=["absent", "present"]),
+ db=dict(type='str', default='', aliases=['login_db']),
+ tablespace=dict(type='str'),
+ owner=dict(type='str'),
+ unlogged=dict(type='bool', default=False),
+ like=dict(type='str'),
+ including=dict(type='str'),
+ rename=dict(type='str'),
+ truncate=dict(type='bool', default=False),
+ columns=dict(type='list', elements='str'),
+ storage_params=dict(type='list', elements='str'),
+ session_role=dict(type='str'),
+ cascade=dict(type='bool', default=False),
+ )
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True,
+ )
+
+ table = module.params["table"]
+ state = module.params["state"]
+ tablespace = module.params["tablespace"]
+ owner = module.params["owner"]
+ unlogged = module.params["unlogged"]
+ like = module.params["like"]
+ including = module.params["including"]
+ newname = module.params["rename"]
+ storage_params = module.params["storage_params"]
+ truncate = module.params["truncate"]
+ columns = module.params["columns"]
+ cascade = module.params["cascade"]
+
+ if state == 'present' and cascade:
+ module.warn("cascade=true is ignored when state=present")
+
+ # Check mutual exclusive parameters:
+ if state == 'absent' and (truncate or newname or columns or tablespace or like or storage_params or unlogged or owner or including):
+ module.fail_json(msg="%s: state=absent is mutually exclusive with: "
+ "truncate, rename, columns, tablespace, "
+ "including, like, storage_params, unlogged, owner" % table)
+
+ if truncate and (newname or columns or like or unlogged or storage_params or owner or tablespace or including):
+ module.fail_json(msg="%s: truncate is mutually exclusive with: "
+ "rename, columns, like, unlogged, including, "
+ "storage_params, owner, tablespace" % table)
+
+ if newname and (columns or like or unlogged or storage_params or owner or tablespace or including):
+ module.fail_json(msg="%s: rename is mutually exclusive with: "
+ "columns, like, unlogged, including, "
+ "storage_params, owner, tablespace" % table)
+
+ if like and columns:
+ module.fail_json(msg="%s: like and columns params are mutually exclusive" % table)
+ if including and not like:
+ module.fail_json(msg="%s: including param needs like param specified" % table)
+
+ conn_params = get_conn_params(module, module.params)
+ db_connection = connect_to_db(module, conn_params, autocommit=False)
+ cursor = db_connection.cursor(cursor_factory=DictCursor)
+
+ if storage_params:
+ storage_params = ','.join(storage_params)
+
+ if columns:
+ columns = ','.join(columns)
+
+ ##############
+ # Do main job:
+ table_obj = Table(table, module, cursor)
+
+ # Set default returned values:
+ changed = False
+ kw = {}
+ kw['table'] = table
+ kw['state'] = ''
+ if table_obj.exists:
+ kw = dict(
+ table=table,
+ state='present',
+ owner=table_obj.info['owner'],
+ tablespace=table_obj.info['tblspace'],
+ storage_params=table_obj.info['storage_params'],
+ )
+
+ if state == 'absent':
+ changed = table_obj.drop(cascade=cascade)
+
+ elif truncate:
+ changed = table_obj.truncate()
+
+ elif newname:
+ changed = table_obj.rename(newname)
+ q = table_obj.executed_queries
+ table_obj = Table(newname, module, cursor)
+ table_obj.executed_queries = q
+
+ elif state == 'present' and not like:
+ changed = table_obj.create(columns, storage_params,
+ tablespace, unlogged, owner)
+
+ elif state == 'present' and like:
+        changed = table_obj.create_like(like, including, tablespace,
+                                        unlogged, storage_params, owner)
+
+ if changed:
+ if module.check_mode:
+ db_connection.rollback()
+ else:
+ db_connection.commit()
+
+ # Refresh table info for RETURN.
+ # Note, if table has been renamed, it gets info by newname:
+ table_obj.get_info()
+ db_connection.commit()
+ if table_obj.exists:
+ kw = dict(
+ table=table,
+ state='present',
+ owner=table_obj.info['owner'],
+ tablespace=table_obj.info['tblspace'],
+ storage_params=table_obj.info['storage_params'],
+ )
+ else:
+ # We just change the table state here
+ # to keep other information about the dropped table:
+ kw['state'] = 'absent'
+
+ kw['queries'] = table_obj.executed_queries
+ kw['changed'] = changed
+ db_connection.close()
+ module.exit_json(**kw)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/postgresql_user.py b/test/support/integration/plugins/modules/postgresql_user.py
new file mode 100644
index 00000000..10afd0a0
--- /dev/null
+++ b/test/support/integration/plugins/modules/postgresql_user.py
@@ -0,0 +1,927 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {
+ 'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'community'
+}
+
+DOCUMENTATION = r'''
+---
+module: postgresql_user
+short_description: Add or remove a user (role) from a PostgreSQL server instance
+description:
+- Adds or removes a user (role) from a PostgreSQL server instance
+ ("cluster" in PostgreSQL terminology) and, optionally,
+ grants the user access to an existing database or tables.
+- A user is a role with login privilege.
+- The fundamental function of the module is to create, or delete, users from
+  a PostgreSQL instance. Privilege assignment, or removal, is an optional
+  step, which works on one database at a time. This allows the module to
+  be called several times in the same playbook to modify the permissions on
+  different databases, or to grant permissions to already existing users.
+- A user cannot be removed until all the privileges have been stripped from
+  the user. In such a situation, if the module tries to remove the user it
+  will fail. To prevent this from happening, the fail_on_user option signals
+  the module to try to remove the user, but to keep going if that is not possible; the
+  module will report if changes happened and, separately, if the user was
+  removed or not.
+version_added: '0.6'
+options:
+ name:
+ description:
+ - Name of the user (role) to add or remove.
+ type: str
+ required: true
+ aliases:
+ - user
+ password:
+ description:
+ - Set the user's password, before 1.4 this was required.
+ - Password can be passed unhashed or hashed (MD5-hashed).
+    - An unhashed password will automatically be hashed when saved into the
+      database if the C(encrypted) parameter is set, otherwise it will be saved in
+      plain text format.
+ - When passing a hashed password it must be generated with the format
+ C('str["md5"] + md5[ password + username ]'), resulting in a total of
+ 35 characters. An easy way to do this is C(echo "md5$(echo -n
+ 'verysecretpasswordJOE' | md5sum | awk '{print $1}')").
+ - Note that if the provided password string is already in MD5-hashed
+ format, then it is used as-is, regardless of C(encrypted) parameter.
+ type: str
+ db:
+ description:
+ - Name of database to connect to and where user's permissions will be granted.
+ type: str
+ aliases:
+ - login_db
+ fail_on_user:
+ description:
+ - If C(yes), fail when user (role) can't be removed. Otherwise just log and continue.
+ default: 'yes'
+ type: bool
+ aliases:
+ - fail_on_role
+ priv:
+ description:
+ - "Slash-separated PostgreSQL privileges string: C(priv1/priv2), where
+ privileges can be defined for database ( allowed options - 'CREATE',
+ 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL'. For example C(CONNECT) ) or
+ for table ( allowed options - 'SELECT', 'INSERT', 'UPDATE', 'DELETE',
+ 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL'. For example
+ C(table:SELECT) ). Mixed example of this string:
+ C(CONNECT/CREATE/table1:SELECT/table2:INSERT)."
+ type: str
+ role_attr_flags:
+ description:
+ - "PostgreSQL user attributes string in the format: CREATEDB,CREATEROLE,SUPERUSER."
+ - Note that '[NO]CREATEUSER' is deprecated.
+ - To create a simple role for using it like a group, use C(NOLOGIN) flag.
+ type: str
+ choices: [ '[NO]SUPERUSER', '[NO]CREATEROLE', '[NO]CREATEDB',
+ '[NO]INHERIT', '[NO]LOGIN', '[NO]REPLICATION', '[NO]BYPASSRLS' ]
+ session_role:
+ version_added: '2.8'
+ description:
+ - Switch to session_role after connecting.
+ - The specified session_role must be a role that the current login_user is a member of.
+ - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally.
+ type: str
+ state:
+ description:
+ - The user (role) state.
+ type: str
+ default: present
+ choices: [ absent, present ]
+ encrypted:
+ description:
+ - Whether the password is stored hashed in the database.
+ - Passwords can be passed already hashed or unhashed, and postgresql
+ ensures the stored password is hashed when C(encrypted) is set.
+    - "Note: PostgreSQL 10 and newer doesn't support unhashed passwords."
+ - Previous to Ansible 2.6, this was C(no) by default.
+ default: 'yes'
+ type: bool
+ version_added: '1.4'
+ expires:
+ description:
+ - The date at which the user's password is to expire.
+    - If set to C('infinity'), the user's password never expires.
+ - Note that this value should be a valid SQL date and time type.
+ type: str
+ version_added: '1.4'
+ no_password_changes:
+ description:
+ - If C(yes), don't inspect database for password changes. Effective when
+ C(pg_authid) is not accessible (such as AWS RDS). Otherwise, make
+ password changes as necessary.
+ default: 'no'
+ type: bool
+ version_added: '2.0'
+ conn_limit:
+ description:
+ - Specifies the user (role) connection limit.
+ type: int
+ version_added: '2.4'
+ ssl_mode:
+ description:
+ - Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated with the server.
+ - See https://www.postgresql.org/docs/current/static/libpq-ssl.html for more information on the modes.
+ - Default of C(prefer) matches libpq default.
+ type: str
+ default: prefer
+ choices: [ allow, disable, prefer, require, verify-ca, verify-full ]
+ version_added: '2.3'
+ ca_cert:
+ description:
+ - Specifies the name of a file containing SSL certificate authority (CA) certificate(s).
+ - If the file exists, the server's certificate will be verified to be signed by one of these authorities.
+ type: str
+ aliases: [ ssl_rootcert ]
+ version_added: '2.3'
+ groups:
+ description:
+ - The list of groups (roles) that need to be granted to the user.
+ type: list
+ elements: str
+ version_added: '2.9'
+ comment:
+ description:
+ - Add a comment on the user (equal to the COMMENT ON ROLE statement result).
+ type: str
+ version_added: '2.10'
+notes:
+- The module creates a user (role) with login privilege by default.
+ Use NOLOGIN role_attr_flags to change this behaviour.
+- If you specify PUBLIC as the user (role), then the privilege changes will apply to all users (roles).
+ You may not specify password or role_attr_flags when the PUBLIC user is specified.
+seealso:
+- module: postgresql_privs
+- module: postgresql_membership
+- module: postgresql_owner
+- name: PostgreSQL database roles
+ description: Complete reference of the PostgreSQL database roles documentation.
+ link: https://www.postgresql.org/docs/current/user-manag.html
+author:
+- Ansible Core Team
+extends_documentation_fragment: postgres
+'''
+
+EXAMPLES = r'''
+- name: Connect to acme database, create django user, and grant access to database and products table
+ postgresql_user:
+ db: acme
+ name: django
+ password: ceec4eif7ya
+ priv: "CONNECT/products:ALL"
+ expires: "Jan 31 2020"
+
+- name: Add a comment on django user
+ postgresql_user:
+ db: acme
+ name: django
+ comment: This is a test user
+
+# Connect to default database, create rails user, set its password (MD5-hashed),
+# and grant privilege to create other databases and demote rails from super user status if user exists
+- name: Create rails user, set MD5-hashed password, grant privs
+ postgresql_user:
+ name: rails
+ password: md59543f1d82624df2b31672ec0f7050460
+ role_attr_flags: CREATEDB,NOSUPERUSER
+
+- name: Connect to acme database and remove test user privileges from there
+ postgresql_user:
+ db: acme
+ name: test
+ priv: "ALL/products:ALL"
+ state: absent
+ fail_on_user: no
+
+- name: Connect to test database, remove test user from cluster
+ postgresql_user:
+ db: test
+ name: test
+ priv: ALL
+ state: absent
+
+- name: Connect to acme database and set user's password with no expire date
+ postgresql_user:
+ db: acme
+ name: django
+ password: mysupersecretword
+ priv: "CONNECT/products:ALL"
+ expires: infinity
+
+# Example privileges string format
+# INSERT,UPDATE/table:SELECT/anothertable:ALL
+
+- name: Connect to test database and remove an existing user's password
+ postgresql_user:
+ db: test
+ user: test
+ password: ""
+
+- name: Create user test and grant group user_ro and user_rw to it
+ postgresql_user:
+ name: test
+ groups:
+ - user_ro
+ - user_rw
+'''
+
+RETURN = r'''
+queries:
+ description: List of executed queries.
+ returned: always
+ type: list
+ sample: ['CREATE USER "alice"', 'GRANT CONNECT ON DATABASE "acme" TO "alice"']
+ version_added: '2.8'
+'''
+
+import itertools
+import re
+import traceback
+from hashlib import md5
+
+try:
+ import psycopg2
+ from psycopg2.extras import DictCursor
+except ImportError:
+ # psycopg2 is checked by connect_to_db()
+ # from ansible.module_utils.postgres
+ pass
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.database import pg_quote_identifier, SQLParseError
+from ansible.module_utils.postgres import (
+ connect_to_db,
+ get_conn_params,
+ PgMembership,
+ postgres_common_argument_spec,
+)
+from ansible.module_utils._text import to_bytes, to_native
+from ansible.module_utils.six import iteritems
+
+
+FLAGS = ('SUPERUSER', 'CREATEROLE', 'CREATEDB', 'INHERIT', 'LOGIN', 'REPLICATION')
+FLAGS_BY_VERSION = {'BYPASSRLS': 90500}
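+# (cursor.connection.server_version reports the server version as an integer,
+# e.g. 90500 for PostgreSQL 9.5.0; FLAGS_BY_VERSION is compared against it)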
+
+VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL')),
+ database=frozenset(
+ ('CREATE', 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL')),
+ )
+
+# map to cope with idiosyncrasies of SUPERUSER and LOGIN
+PRIV_TO_AUTHID_COLUMN = dict(SUPERUSER='rolsuper', CREATEROLE='rolcreaterole',
+ CREATEDB='rolcreatedb', INHERIT='rolinherit', LOGIN='rolcanlogin',
+ REPLICATION='rolreplication', BYPASSRLS='rolbypassrls')
+
+executed_queries = []
+
+
+class InvalidFlagsError(Exception):
+ pass
+
+
+class InvalidPrivsError(Exception):
+ pass
+
+# ===========================================
+# PostgreSQL module specific support methods.
+#
+
+
+def user_exists(cursor, user):
+ # The PUBLIC user is a special case that is always there
+ if user == 'PUBLIC':
+ return True
+ query = "SELECT rolname FROM pg_roles WHERE rolname=%(user)s"
+ cursor.execute(query, {'user': user})
+ return cursor.rowcount > 0
+
+
+def user_add(cursor, user, password, role_attr_flags, encrypted, expires, conn_limit):
+ """Create a new database user (role)."""
+    # Note: role_attr_flags is validated and escaped by parse_role_attrs;
+    # 'encrypted' is a fixed SQL keyword, not user input
+ query_password_data = dict(password=password, expires=expires)
+ query = ['CREATE USER "%(user)s"' %
+ {"user": user}]
+ if password is not None and password != '':
+ query.append("WITH %(crypt)s" % {"crypt": encrypted})
+ query.append("PASSWORD %(password)s")
+ if expires is not None:
+ query.append("VALID UNTIL %(expires)s")
+ if conn_limit is not None:
+ query.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit})
+ query.append(role_attr_flags)
+ query = ' '.join(query)
+ executed_queries.append(query)
+ cursor.execute(query, query_password_data)
+ return True
+
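+# For illustration (hypothetical values): for user "django" with an encrypted
+# password, an expiry date and LOGIN, user_add() above assembles a query such as
+#   CREATE USER "django" WITH ENCRYPTED PASSWORD %(password)s
+#   VALID UNTIL %(expires)s LOGIN
+# and passes the password and expiry separately as query parameters.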
+
+def user_should_we_change_password(current_role_attrs, user, password, encrypted):
+ """Check if we should change the user's password.
+
+    Compare the proposed password with the existing one, comparing
+    hashes if encrypted. If we can't access the current password,
+    assume it needs changing.
+ """
+
+ if current_role_attrs is None:
+        # On some databases, e.g. AWS RDS instances, there is no access to
+        # the pg_authid relation to check the pre-existing password, so we
+        # just assume the password is different
+ return True
+
+ # Do we actually need to do anything?
+ pwchanging = False
+ if password is not None:
+ # Empty password means that the role shouldn't have a password, which
+ # means we need to check if the current password is None.
+ if password == '':
+ if current_role_attrs['rolpassword'] is not None:
+ pwchanging = True
+ # 32: MD5 hashes are represented as a sequence of 32 hexadecimal digits
+ # 3: The size of the 'md5' prefix
+        # When the provided password looks like an MD5 hash, the value of
+        # 'encrypted' is ignored.
+ elif (password.startswith('md5') and len(password) == 32 + 3) or encrypted == 'UNENCRYPTED':
+ if password != current_role_attrs['rolpassword']:
+ pwchanging = True
+ elif encrypted == 'ENCRYPTED':
+ hashed_password = 'md5{0}'.format(md5(to_bytes(password) + to_bytes(user)).hexdigest())
+ if hashed_password != current_role_attrs['rolpassword']:
+ pwchanging = True
+
+ return pwchanging
+
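+# Illustrative sketch (hypothetical values): PostgreSQL stores MD5 passwords as
+# 'md5' + md5(password + username), so for user "alice" with password "secret"
+# the stored value would be 'md5' + md5(b'secret' + b'alice').hexdigest();
+# user_should_we_change_password() above recomputes exactly this for comparison.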
+
+def user_alter(db_connection, module, user, password, role_attr_flags, encrypted, expires, no_password_changes, conn_limit):
+ """Change user password and/or attributes. Return True if changed, False otherwise."""
+ changed = False
+
+ cursor = db_connection.cursor(cursor_factory=DictCursor)
+    # Note: role_attr_flags is validated and escaped by parse_role_attrs;
+    # 'encrypted' is a fixed SQL keyword, not user input
+ if user == 'PUBLIC':
+ if password is not None:
+ module.fail_json(msg="cannot change the password for PUBLIC user")
+ elif role_attr_flags != '':
+ module.fail_json(msg="cannot change the role_attr_flags for PUBLIC user")
+ else:
+ return False
+
+ # Handle passwords.
+ if not no_password_changes and (password is not None or role_attr_flags != '' or expires is not None or conn_limit is not None):
+ # Select password and all flag-like columns in order to verify changes.
+ try:
+ select = "SELECT * FROM pg_authid where rolname=%(user)s"
+ cursor.execute(select, {"user": user})
+ # Grab current role attributes.
+ current_role_attrs = cursor.fetchone()
+ except psycopg2.ProgrammingError:
+ current_role_attrs = None
+ db_connection.rollback()
+
+ pwchanging = user_should_we_change_password(current_role_attrs, user, password, encrypted)
+
+ if current_role_attrs is None:
+ try:
+                # AWS RDS instances do not allow the user to access pg_authid,
+                # so try to get current_role_attrs from the pg_roles table
+ select = "SELECT * FROM pg_roles where rolname=%(user)s"
+ cursor.execute(select, {"user": user})
+ # Grab current role attributes from pg_roles
+ current_role_attrs = cursor.fetchone()
+ except psycopg2.ProgrammingError as e:
+ db_connection.rollback()
+ module.fail_json(msg="Failed to get role details for current user %s: %s" % (user, e))
+
+ role_attr_flags_changing = False
+ if role_attr_flags:
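+            # Turn e.g. 'SUPERUSER NOCREATEDB' into
+            # {'SUPERUSER': True, 'CREATEDB': False} so each flag can be
+            # compared with its pg_authid/pg_roles column value.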
+ role_attr_flags_dict = {}
+ for r in role_attr_flags.split(' '):
+ if r.startswith('NO'):
+ role_attr_flags_dict[r.replace('NO', '', 1)] = False
+ else:
+ role_attr_flags_dict[r] = True
+
+ for role_attr_name, role_attr_value in role_attr_flags_dict.items():
+ if current_role_attrs[PRIV_TO_AUTHID_COLUMN[role_attr_name]] != role_attr_value:
+ role_attr_flags_changing = True
+
+ if expires is not None:
+ cursor.execute("SELECT %s::timestamptz;", (expires,))
+ expires_with_tz = cursor.fetchone()[0]
+ expires_changing = expires_with_tz != current_role_attrs.get('rolvaliduntil')
+ else:
+ expires_changing = False
+
+ conn_limit_changing = (conn_limit is not None and conn_limit != current_role_attrs['rolconnlimit'])
+
+ if not pwchanging and not role_attr_flags_changing and not expires_changing and not conn_limit_changing:
+ return False
+
+ alter = ['ALTER USER "%(user)s"' % {"user": user}]
+ if pwchanging:
+ if password != '':
+ alter.append("WITH %(crypt)s" % {"crypt": encrypted})
+ alter.append("PASSWORD %(password)s")
+ else:
+ alter.append("WITH PASSWORD NULL")
+ alter.append(role_attr_flags)
+ elif role_attr_flags:
+ alter.append('WITH %s' % role_attr_flags)
+ if expires is not None:
+ alter.append("VALID UNTIL %(expires)s")
+ if conn_limit is not None:
+ alter.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit})
+
+ query_password_data = dict(password=password, expires=expires)
+ try:
+ cursor.execute(' '.join(alter), query_password_data)
+ changed = True
+ except psycopg2.InternalError as e:
+ if e.pgcode == '25006':
+ # Handle errors due to read-only transactions indicated by pgcode 25006
+ # ERROR: cannot execute ALTER ROLE in a read-only transaction
+ changed = False
+ module.fail_json(msg=e.pgerror, exception=traceback.format_exc())
+ return changed
+ else:
+ raise psycopg2.InternalError(e)
+ except psycopg2.NotSupportedError as e:
+ module.fail_json(msg=e.pgerror, exception=traceback.format_exc())
+
+ elif no_password_changes and role_attr_flags != '':
+ # Grab role information from pg_roles instead of pg_authid
+ select = "SELECT * FROM pg_roles where rolname=%(user)s"
+ cursor.execute(select, {"user": user})
+ # Grab current role attributes.
+ current_role_attrs = cursor.fetchone()
+
+ role_attr_flags_changing = False
+
+ if role_attr_flags:
+ role_attr_flags_dict = {}
+ for r in role_attr_flags.split(' '):
+ if r.startswith('NO'):
+ role_attr_flags_dict[r.replace('NO', '', 1)] = False
+ else:
+ role_attr_flags_dict[r] = True
+
+ for role_attr_name, role_attr_value in role_attr_flags_dict.items():
+ if current_role_attrs[PRIV_TO_AUTHID_COLUMN[role_attr_name]] != role_attr_value:
+ role_attr_flags_changing = True
+
+ if not role_attr_flags_changing:
+ return False
+
+ alter = ['ALTER USER "%(user)s"' %
+ {"user": user}]
+ if role_attr_flags:
+ alter.append('WITH %s' % role_attr_flags)
+
+ try:
+ cursor.execute(' '.join(alter))
+ except psycopg2.InternalError as e:
+ if e.pgcode == '25006':
+ # Handle errors due to read-only transactions indicated by pgcode 25006
+ # ERROR: cannot execute ALTER ROLE in a read-only transaction
+ changed = False
+ module.fail_json(msg=e.pgerror, exception=traceback.format_exc())
+ return changed
+ else:
+ raise psycopg2.InternalError(e)
+
+ # Grab new role attributes.
+ cursor.execute(select, {"user": user})
+ new_role_attrs = cursor.fetchone()
+
+ # Detect any differences between current_ and new_role_attrs.
+ changed = current_role_attrs != new_role_attrs
+
+ return changed
+
+
+def user_delete(cursor, user):
+    """Try to remove a user. Return True if successful, otherwise False."""
+ cursor.execute("SAVEPOINT ansible_pgsql_user_delete")
+ try:
+ query = 'DROP USER "%s"' % user
+ executed_queries.append(query)
+ cursor.execute(query)
+ except Exception:
+ cursor.execute("ROLLBACK TO SAVEPOINT ansible_pgsql_user_delete")
+ cursor.execute("RELEASE SAVEPOINT ansible_pgsql_user_delete")
+ return False
+
+ cursor.execute("RELEASE SAVEPOINT ansible_pgsql_user_delete")
+ return True
+
+
+def has_table_privileges(cursor, user, table, privs):
+ """
+ Return the difference between the privileges that a user already has and
+ the privileges that they desire to have.
+
+ :returns: tuple of:
+ * privileges that they have and were requested
+ * privileges they currently hold but were not requested
+ * privileges requested that they do not hold
+ """
+ cur_privs = get_table_privileges(cursor, user, table)
+ have_currently = cur_privs.intersection(privs)
+ other_current = cur_privs.difference(privs)
+ desired = privs.difference(cur_privs)
+ return (have_currently, other_current, desired)
+
+
+def get_table_privileges(cursor, user, table):
+ if '.' in table:
+ schema, table = table.split('.', 1)
+ else:
+ schema = 'public'
+ query = ("SELECT privilege_type FROM information_schema.role_table_grants "
+ "WHERE grantee=%(user)s AND table_name=%(table)s AND table_schema=%(schema)s")
+ cursor.execute(query, {'user': user, 'table': table, 'schema': schema})
+ return frozenset([x[0] for x in cursor.fetchall()])
+
+
+def grant_table_privileges(cursor, user, table, privs):
+ # Note: priv escaped by parse_privs
+ privs = ', '.join(privs)
+ query = 'GRANT %s ON TABLE %s TO "%s"' % (
+ privs, pg_quote_identifier(table, 'table'), user)
+ executed_queries.append(query)
+ cursor.execute(query)
+
+
+def revoke_table_privileges(cursor, user, table, privs):
+ # Note: priv escaped by parse_privs
+ privs = ', '.join(privs)
+ query = 'REVOKE %s ON TABLE %s FROM "%s"' % (
+ privs, pg_quote_identifier(table, 'table'), user)
+ executed_queries.append(query)
+ cursor.execute(query)
+
+
+def get_database_privileges(cursor, user, db):
+ priv_map = {
+ 'C': 'CREATE',
+ 'T': 'TEMPORARY',
+ 'c': 'CONNECT',
+ }
+ query = 'SELECT datacl FROM pg_database WHERE datname = %s'
+ cursor.execute(query, (db,))
+ datacl = cursor.fetchone()[0]
+ if datacl is None:
+ return set()
+ r = re.search(r'%s\\?"?=(C?T?c?)/[^,]+,?' % user, datacl)
+ if r is None:
+ return set()
+ o = set()
+ for v in r.group(1):
+ o.add(priv_map[v])
+ return normalize_privileges(o, 'database')
+
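+# For illustration (hypothetical ACL): a pg_database.datacl value such as
+# '{postgres=CTc/postgres,django=c/postgres}' yields set(['CONNECT']) for user
+# "django" once 'c' is mapped through priv_map and normalized.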
+
+def has_database_privileges(cursor, user, db, privs):
+ """
+ Return the difference between the privileges that a user already has and
+ the privileges that they desire to have.
+
+ :returns: tuple of:
+ * privileges that they have and were requested
+ * privileges they currently hold but were not requested
+ * privileges requested that they do not hold
+ """
+ cur_privs = get_database_privileges(cursor, user, db)
+ have_currently = cur_privs.intersection(privs)
+ other_current = cur_privs.difference(privs)
+ desired = privs.difference(cur_privs)
+ return (have_currently, other_current, desired)
+
+
+def grant_database_privileges(cursor, user, db, privs):
+ # Note: priv escaped by parse_privs
+ privs = ', '.join(privs)
+ if user == "PUBLIC":
+ query = 'GRANT %s ON DATABASE %s TO PUBLIC' % (
+ privs, pg_quote_identifier(db, 'database'))
+ else:
+ query = 'GRANT %s ON DATABASE %s TO "%s"' % (
+ privs, pg_quote_identifier(db, 'database'), user)
+
+ executed_queries.append(query)
+ cursor.execute(query)
+
+
+def revoke_database_privileges(cursor, user, db, privs):
+ # Note: priv escaped by parse_privs
+ privs = ', '.join(privs)
+ if user == "PUBLIC":
+ query = 'REVOKE %s ON DATABASE %s FROM PUBLIC' % (
+ privs, pg_quote_identifier(db, 'database'))
+ else:
+ query = 'REVOKE %s ON DATABASE %s FROM "%s"' % (
+ privs, pg_quote_identifier(db, 'database'), user)
+
+ executed_queries.append(query)
+ cursor.execute(query)
+
+
+def revoke_privileges(cursor, user, privs):
+ if privs is None:
+ return False
+
+ revoke_funcs = dict(table=revoke_table_privileges,
+ database=revoke_database_privileges)
+ check_funcs = dict(table=has_table_privileges,
+ database=has_database_privileges)
+
+ changed = False
+ for type_ in privs:
+ for name, privileges in iteritems(privs[type_]):
+            # Check whether any of the privileges requested to be removed are
+            # currently granted to the user
+ differences = check_funcs[type_](cursor, user, name, privileges)
+ if differences[0]:
+ revoke_funcs[type_](cursor, user, name, privileges)
+ changed = True
+ return changed
+
+
+def grant_privileges(cursor, user, privs):
+ if privs is None:
+ return False
+
+ grant_funcs = dict(table=grant_table_privileges,
+ database=grant_database_privileges)
+ check_funcs = dict(table=has_table_privileges,
+ database=has_database_privileges)
+
+ changed = False
+ for type_ in privs:
+ for name, privileges in iteritems(privs[type_]):
+            # Check whether any of the privileges requested for the user are
+            # currently missing
+ differences = check_funcs[type_](cursor, user, name, privileges)
+ if differences[2]:
+ grant_funcs[type_](cursor, user, name, privileges)
+ changed = True
+ return changed
+
+
+def parse_role_attrs(cursor, role_attr_flags):
+ """
+ Parse role attributes string for user creation.
+ Format:
+
+ attributes[,attributes,...]
+
+ Where:
+
+ attributes := CREATEDB,CREATEROLE,NOSUPERUSER,...
+ [ "[NO]SUPERUSER","[NO]CREATEROLE", "[NO]CREATEDB",
+ "[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION",
+ "[NO]BYPASSRLS" ]
+
+ Note: "[NO]BYPASSRLS" role attribute introduced in 9.5
+ Note: "[NO]CREATEUSER" role attribute is deprecated.
+
+ """
+ flags = frozenset(role.upper() for role in role_attr_flags.split(',') if role)
+
+ valid_flags = frozenset(itertools.chain(FLAGS, get_valid_flags_by_version(cursor)))
+ valid_flags = frozenset(itertools.chain(valid_flags, ('NO%s' % flag for flag in valid_flags)))
+
+ if not flags.issubset(valid_flags):
+ raise InvalidFlagsError('Invalid role_attr_flags specified: %s' %
+ ' '.join(flags.difference(valid_flags)))
+
+ return ' '.join(flags)
+
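+# For illustration (hypothetical input): parse_role_attrs(cursor,
+# 'LOGIN,CREATEDB,NOSUPERUSER') returns the three flags joined by single
+# spaces (in unspecified order, since a frozenset is used), ready to be
+# appended to a CREATE/ALTER USER statement.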
+
+def normalize_privileges(privs, type_):
+ new_privs = set(privs)
+ if 'ALL' in new_privs:
+ new_privs.update(VALID_PRIVS[type_])
+ new_privs.remove('ALL')
+ if 'TEMP' in new_privs:
+ new_privs.add('TEMPORARY')
+ new_privs.remove('TEMP')
+
+ return new_privs
+
+
+def parse_privs(privs, db):
+ """
+ Parse privilege string to determine permissions for database db.
+ Format:
+
+ privileges[/privileges/...]
+
+ Where:
+
+ privileges := DATABASE_PRIVILEGES[,DATABASE_PRIVILEGES,...] |
+ TABLE_NAME:TABLE_PRIVILEGES[,TABLE_PRIVILEGES,...]
+ """
+ if privs is None:
+ return privs
+
+ o_privs = {
+ 'database': {},
+ 'table': {}
+ }
+ for token in privs.split('/'):
+ if ':' not in token:
+ type_ = 'database'
+ name = db
+ priv_set = frozenset(x.strip().upper()
+ for x in token.split(',') if x.strip())
+ else:
+ type_ = 'table'
+ name, privileges = token.split(':', 1)
+ priv_set = frozenset(x.strip().upper()
+ for x in privileges.split(',') if x.strip())
+
+ if not priv_set.issubset(VALID_PRIVS[type_]):
+ raise InvalidPrivsError('Invalid privs specified for %s: %s' %
+ (type_, ' '.join(priv_set.difference(VALID_PRIVS[type_]))))
+
+ priv_set = normalize_privileges(priv_set, type_)
+ o_privs[type_][name] = priv_set
+
+ return o_privs
+
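+# For illustration (hypothetical input): parse_privs('CONNECT/products:SELECT,INSERT', 'acme')
+# returns {'database': {'acme': set(['CONNECT'])},
+#          'table': {'products': set(['SELECT', 'INSERT'])}}.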
+
+def get_valid_flags_by_version(cursor):
+ """
+ Some role attributes were introduced after certain versions. We want to
+ compile a list of valid flags against the current Postgres version.
+ """
+ current_version = cursor.connection.server_version
+
+ return [
+ flag
+ for flag, version_introduced in FLAGS_BY_VERSION.items()
+ if current_version >= version_introduced
+ ]
+
+
+def get_comment(cursor, user):
+ """Get user's comment."""
+ query = ("SELECT pg_catalog.shobj_description(r.oid, 'pg_authid') "
+ "FROM pg_catalog.pg_roles r "
+ "WHERE r.rolname = %(user)s")
+ cursor.execute(query, {'user': user})
+ return cursor.fetchone()[0]
+
+
+def add_comment(cursor, user, comment):
+ """Add comment on user."""
+ if comment != get_comment(cursor, user):
+ query = 'COMMENT ON ROLE "%s" IS ' % user
+ cursor.execute(query + '%(comment)s', {'comment': comment})
+ executed_queries.append(cursor.mogrify(query + '%(comment)s', {'comment': comment}))
+ return True
+ else:
+ return False
+
+
+# ===========================================
+# Module execution.
+#
+
+def main():
+ argument_spec = postgres_common_argument_spec()
+ argument_spec.update(
+ user=dict(type='str', required=True, aliases=['name']),
+ password=dict(type='str', default=None, no_log=True),
+ state=dict(type='str', default='present', choices=['absent', 'present']),
+ priv=dict(type='str', default=None),
+ db=dict(type='str', default='', aliases=['login_db']),
+ fail_on_user=dict(type='bool', default='yes', aliases=['fail_on_role']),
+ role_attr_flags=dict(type='str', default=''),
+ encrypted=dict(type='bool', default='yes'),
+ no_password_changes=dict(type='bool', default='no'),
+ expires=dict(type='str', default=None),
+ conn_limit=dict(type='int', default=None),
+ session_role=dict(type='str'),
+ groups=dict(type='list', elements='str'),
+ comment=dict(type='str', default=None),
+ )
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ supports_check_mode=True
+ )
+
+ user = module.params["user"]
+ password = module.params["password"]
+ state = module.params["state"]
+ fail_on_user = module.params["fail_on_user"]
+ if module.params['db'] == '' and module.params["priv"] is not None:
+ module.fail_json(msg="privileges require a database to be specified")
+ privs = parse_privs(module.params["priv"], module.params["db"])
+ no_password_changes = module.params["no_password_changes"]
+ if module.params["encrypted"]:
+ encrypted = "ENCRYPTED"
+ else:
+ encrypted = "UNENCRYPTED"
+ expires = module.params["expires"]
+ conn_limit = module.params["conn_limit"]
+ role_attr_flags = module.params["role_attr_flags"]
+ groups = module.params["groups"]
+ if groups:
+ groups = [e.strip() for e in groups]
+ comment = module.params["comment"]
+
+ conn_params = get_conn_params(module, module.params, warn_db_default=False)
+ db_connection = connect_to_db(module, conn_params)
+ cursor = db_connection.cursor(cursor_factory=DictCursor)
+
+ try:
+ role_attr_flags = parse_role_attrs(cursor, role_attr_flags)
+ except InvalidFlagsError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+
+ kw = dict(user=user)
+ changed = False
+ user_removed = False
+
+ if state == "present":
+ if user_exists(cursor, user):
+ try:
+ changed = user_alter(db_connection, module, user, password,
+ role_attr_flags, encrypted, expires, no_password_changes, conn_limit)
+ except SQLParseError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+ else:
+ try:
+ changed = user_add(cursor, user, password,
+ role_attr_flags, encrypted, expires, conn_limit)
+ except psycopg2.ProgrammingError as e:
+ module.fail_json(msg="Unable to add user with given requirement "
+ "due to : %s" % to_native(e),
+ exception=traceback.format_exc())
+ except SQLParseError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+ try:
+ changed = grant_privileges(cursor, user, privs) or changed
+ except SQLParseError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+
+ if groups:
+ target_roles = []
+ target_roles.append(user)
+ pg_membership = PgMembership(module, cursor, groups, target_roles)
+ changed = pg_membership.grant() or changed
+ executed_queries.extend(pg_membership.executed_queries)
+
+ if comment is not None:
+ try:
+ changed = add_comment(cursor, user, comment) or changed
+ except Exception as e:
+ module.fail_json(msg='Unable to add comment on role: %s' % to_native(e),
+ exception=traceback.format_exc())
+
+ else:
+ if user_exists(cursor, user):
+ if module.check_mode:
+ changed = True
+ kw['user_removed'] = True
+ else:
+ try:
+ changed = revoke_privileges(cursor, user, privs)
+ user_removed = user_delete(cursor, user)
+ except SQLParseError as e:
+ module.fail_json(msg=to_native(e), exception=traceback.format_exc())
+ changed = changed or user_removed
+ if fail_on_user and not user_removed:
+ msg = "Unable to remove user"
+ module.fail_json(msg=msg)
+ kw['user_removed'] = user_removed
+
+ if changed:
+ if module.check_mode:
+ db_connection.rollback()
+ else:
+ db_connection.commit()
+
+ kw['changed'] = changed
+ kw['queries'] = executed_queries
+ module.exit_json(**kw)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/rabbitmq_plugin.py b/test/support/integration/plugins/modules/rabbitmq_plugin.py
new file mode 100644
index 00000000..301bbfe2
--- /dev/null
+++ b/test/support/integration/plugins/modules/rabbitmq_plugin.py
@@ -0,0 +1,180 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2013, Chatham Financial <oss@chathamfinancial.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {
+ 'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'
+}
+
+
+DOCUMENTATION = '''
+---
+module: rabbitmq_plugin
+short_description: Manage RabbitMQ plugins
+description:
+ - This module can be used to enable or disable RabbitMQ plugins.
+version_added: "1.1"
+author:
+ - Chris Hoffman (@chrishoffman)
+options:
+ names:
+ description:
+      - Comma-separated list of plugin names. A single plugin name is also accepted.
+ required: true
+ aliases: [name]
+ new_only:
+ description:
+ - Only enable missing plugins.
+ - Does not disable plugins that are not in the names list.
+ type: bool
+ default: "no"
+ state:
+ description:
+ - Specify if plugins are to be enabled or disabled.
+ default: enabled
+ choices: [enabled, disabled]
+ prefix:
+ description:
+      - Specify a custom install prefix of a RabbitMQ installation.
+ version_added: "1.3"
+'''
+
+EXAMPLES = '''
+- name: Enables the rabbitmq_management plugin
+ rabbitmq_plugin:
+ names: rabbitmq_management
+ state: enabled
+
+- name: Enable multiple rabbitmq plugins
+ rabbitmq_plugin:
+ names: rabbitmq_management,rabbitmq_management_visualiser
+ state: enabled
+
+- name: Disable plugin
+ rabbitmq_plugin:
+ names: rabbitmq_management
+ state: disabled
+
+- name: Enable plugins in the list without disabling the ones already enabled
+ rabbitmq_plugin:
+ names: rabbitmq_management,rabbitmq_management_visualiser,rabbitmq_shovel,rabbitmq_shovel_management
+ state: enabled
+ new_only: 'yes'
+'''
+
+RETURN = '''
+enabled:
+  description: List of plugins enabled during the task run
+ returned: always
+ type: list
+ sample: ["rabbitmq_management"]
+disabled:
+  description: List of plugins disabled during the task run
+ returned: always
+ type: list
+ sample: ["rabbitmq_management"]
+'''
+
+import os
+from ansible.module_utils.basic import AnsibleModule
+
+
+class RabbitMqPlugins(object):
+
+ def __init__(self, module):
+ self.module = module
+ bin_path = ''
+ if module.params['prefix']:
+ if os.path.isdir(os.path.join(module.params['prefix'], 'bin')):
+ bin_path = os.path.join(module.params['prefix'], 'bin')
+ elif os.path.isdir(os.path.join(module.params['prefix'], 'sbin')):
+ bin_path = os.path.join(module.params['prefix'], 'sbin')
+ else:
+ # No such path exists.
+ module.fail_json(msg="No binary folder in prefix %s" % module.params['prefix'])
+
+ self._rabbitmq_plugins = os.path.join(bin_path, "rabbitmq-plugins")
+ else:
+ self._rabbitmq_plugins = module.get_bin_path('rabbitmq-plugins', True)
+
+ def _exec(self, args, run_in_check_mode=False):
+        if not self.module.check_mode or run_in_check_mode:
+ cmd = [self._rabbitmq_plugins]
+ rc, out, err = self.module.run_command(cmd + args, check_rc=True)
+ return out.splitlines()
+ return list()
+
+ def get_all(self):
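+        # 'rabbitmq-plugins list -E -m' prints only explicitly enabled
+        # plugins, one bare plugin name per line (minimal output).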
+ list_output = self._exec(['list', '-E', '-m'], True)
+ plugins = []
+ for plugin in list_output:
+ if not plugin:
+ break
+ plugins.append(plugin)
+
+ return plugins
+
+ def enable(self, name):
+ self._exec(['enable', name])
+
+ def disable(self, name):
+ self._exec(['disable', name])
+
+
+def main():
+ arg_spec = dict(
+ names=dict(required=True, aliases=['name']),
+ new_only=dict(default='no', type='bool'),
+ state=dict(default='enabled', choices=['enabled', 'disabled']),
+ prefix=dict(required=False, default=None)
+ )
+ module = AnsibleModule(
+ argument_spec=arg_spec,
+ supports_check_mode=True
+ )
+
+ result = dict()
+ names = module.params['names'].split(',')
+ new_only = module.params['new_only']
+ state = module.params['state']
+
+ rabbitmq_plugins = RabbitMqPlugins(module)
+ enabled_plugins = rabbitmq_plugins.get_all()
+
+ enabled = []
+ disabled = []
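+    # With state=enabled and new_only=no, the names list is authoritative:
+    # enabled plugins not in the list are disabled first, then any missing
+    # ones are enabled.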
+ if state == 'enabled':
+ if not new_only:
+ for plugin in enabled_plugins:
+ if " " in plugin:
+ continue
+ if plugin not in names:
+ rabbitmq_plugins.disable(plugin)
+ disabled.append(plugin)
+
+ for name in names:
+ if name not in enabled_plugins:
+ rabbitmq_plugins.enable(name)
+ enabled.append(name)
+ else:
+ for plugin in enabled_plugins:
+ if plugin in names:
+ rabbitmq_plugins.disable(plugin)
+ disabled.append(plugin)
+
+ result['changed'] = len(enabled) > 0 or len(disabled) > 0
+ result['enabled'] = enabled
+ result['disabled'] = disabled
+ module.exit_json(**result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/rabbitmq_queue.py b/test/support/integration/plugins/modules/rabbitmq_queue.py
new file mode 100644
index 00000000..567ec813
--- /dev/null
+++ b/test/support/integration/plugins/modules/rabbitmq_queue.py
@@ -0,0 +1,257 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Manuel Sousa <manuel.sousa@gmail.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: rabbitmq_queue
+author: Manuel Sousa (@manuel-sousa)
+version_added: "2.0"
+
+short_description: Manage RabbitMQ queues
+description:
+  - This module uses the RabbitMQ REST API to create/delete queues
+requirements: [ "requests >= 1.0.0" ]
+options:
+ name:
+ description:
+ - Name of the queue
+ required: true
+ state:
+ description:
+ - Whether the queue should be present or absent
+ choices: [ "present", "absent" ]
+ default: present
+ durable:
+ description:
+            - Whether the queue is durable or not
+ type: bool
+ default: 'yes'
+ auto_delete:
+ description:
+            - If the queue should delete itself after all consumers/queues have unbound from it
+ type: bool
+ default: 'no'
+ message_ttl:
+ description:
+            - How long a message can live in the queue before it is discarded (milliseconds)
+ default: forever
+ auto_expires:
+ description:
+ - How long a queue can be unused before it is automatically deleted (milliseconds)
+ default: forever
+ max_length:
+ description:
+            - How many messages the queue can contain before it starts rejecting new ones
+ default: no limit
+ dead_letter_exchange:
+ description:
+ - Optional name of an exchange to which messages will be republished if they
+ - are rejected or expire
+ dead_letter_routing_key:
+ description:
+ - Optional replacement routing key to use when a message is dead-lettered.
+            - The original routing key will be used if unset
+ max_priority:
+ description:
+ - Maximum number of priority levels for the queue to support.
+ - If not set, the queue will not support message priorities.
+ - Larger numbers indicate higher priority.
+ version_added: "2.4"
+ arguments:
+ description:
+            - Extra arguments for the queue. If defined, this argument is a key/value dictionary
+ default: {}
+extends_documentation_fragment:
+ - rabbitmq
+'''
+
+EXAMPLES = '''
+# Create a queue
+- rabbitmq_queue:
+ name: myQueue
+
+# Create a queue on remote host
+- rabbitmq_queue:
+ name: myRemoteQueue
+ login_user: user
+ login_password: secret
+ login_host: remote.example.org
+'''
+
+import json
+import traceback
+
+REQUESTS_IMP_ERR = None
+try:
+ import requests
+ HAS_REQUESTS = True
+except ImportError:
+ REQUESTS_IMP_ERR = traceback.format_exc()
+ HAS_REQUESTS = False
+
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils.six.moves.urllib import parse as urllib_parse
+from ansible.module_utils.rabbitmq import rabbitmq_argument_spec
+
+
+def main():
+
+ argument_spec = rabbitmq_argument_spec()
+ argument_spec.update(
+ dict(
+ state=dict(default='present', choices=['present', 'absent'], type='str'),
+ name=dict(required=True, type='str'),
+ durable=dict(default=True, type='bool'),
+ auto_delete=dict(default=False, type='bool'),
+ message_ttl=dict(default=None, type='int'),
+ auto_expires=dict(default=None, type='int'),
+ max_length=dict(default=None, type='int'),
+ dead_letter_exchange=dict(default=None, type='str'),
+ dead_letter_routing_key=dict(default=None, type='str'),
+ arguments=dict(default=dict(), type='dict'),
+ max_priority=dict(default=None, type='int')
+ )
+ )
+ module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=True)
+
+ url = "%s://%s:%s/api/queues/%s/%s" % (
+ module.params['login_protocol'],
+ module.params['login_host'],
+ module.params['login_port'],
+ urllib_parse.quote(module.params['vhost'], ''),
+ module.params['name']
+ )
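+    # For illustration (hypothetical host): with the default management port this
+    # yields e.g. http://localhost:15672/api/queues/%2F/myQueue, where the default
+    # vhost '/' has been URL-encoded to %2F by urllib_parse.quote above.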
+
+ if not HAS_REQUESTS:
+ module.fail_json(msg=missing_required_lib("requests"), exception=REQUESTS_IMP_ERR)
+
+ result = dict(changed=False, name=module.params['name'])
+
+ # Check if queue already exists
+ r = requests.get(url, auth=(module.params['login_user'], module.params['login_password']),
+ verify=module.params['ca_cert'], cert=(module.params['client_cert'], module.params['client_key']))
+
+ if r.status_code == 200:
+ queue_exists = True
+ response = r.json()
+ elif r.status_code == 404:
+ queue_exists = False
+ response = r.text
+ else:
+ module.fail_json(
+            msg="Invalid response from the REST API when trying to check if queue exists",
+ details=r.text
+ )
+
+ if module.params['state'] == 'present':
+ change_required = not queue_exists
+ else:
+ change_required = queue_exists
+
+ # Check if attributes change on existing queue
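+    # The HTTP API cannot modify properties of an existing queue, so each
+    # tracked property and x- argument must already match, or we fail below.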
+ if not change_required and r.status_code == 200 and module.params['state'] == 'present':
+ if not (
+ response['durable'] == module.params['durable'] and
+ response['auto_delete'] == module.params['auto_delete'] and
+ (
+ ('x-message-ttl' in response['arguments'] and response['arguments']['x-message-ttl'] == module.params['message_ttl']) or
+ ('x-message-ttl' not in response['arguments'] and module.params['message_ttl'] is None)
+ ) and
+ (
+ ('x-expires' in response['arguments'] and response['arguments']['x-expires'] == module.params['auto_expires']) or
+ ('x-expires' not in response['arguments'] and module.params['auto_expires'] is None)
+ ) and
+ (
+ ('x-max-length' in response['arguments'] and response['arguments']['x-max-length'] == module.params['max_length']) or
+ ('x-max-length' not in response['arguments'] and module.params['max_length'] is None)
+ ) and
+ (
+ ('x-dead-letter-exchange' in response['arguments'] and
+ response['arguments']['x-dead-letter-exchange'] == module.params['dead_letter_exchange']) or
+ ('x-dead-letter-exchange' not in response['arguments'] and module.params['dead_letter_exchange'] is None)
+ ) and
+ (
+ ('x-dead-letter-routing-key' in response['arguments'] and
+ response['arguments']['x-dead-letter-routing-key'] == module.params['dead_letter_routing_key']) or
+ ('x-dead-letter-routing-key' not in response['arguments'] and module.params['dead_letter_routing_key'] is None)
+ ) and
+ (
+ ('x-max-priority' in response['arguments'] and
+ response['arguments']['x-max-priority'] == module.params['max_priority']) or
+ ('x-max-priority' not in response['arguments'] and module.params['max_priority'] is None)
+ )
+ ):
+ module.fail_json(
+            msg="The RabbitMQ REST API doesn't support attribute changes for existing queues",
+ )
+
+ # Copy parameters to arguments as used by RabbitMQ
+ for k, v in {
+ 'message_ttl': 'x-message-ttl',
+ 'auto_expires': 'x-expires',
+ 'max_length': 'x-max-length',
+ 'dead_letter_exchange': 'x-dead-letter-exchange',
+ 'dead_letter_routing_key': 'x-dead-letter-routing-key',
+ 'max_priority': 'x-max-priority'
+ }.items():
+ if module.params[k] is not None:
+ module.params['arguments'][v] = module.params[k]
+
+ # Exit if check_mode
+ if module.check_mode:
+ result['changed'] = change_required
+ result['details'] = response
+ result['arguments'] = module.params['arguments']
+ module.exit_json(**result)
+
+ # Do changes
+ if change_required:
+ if module.params['state'] == 'present':
+ r = requests.put(
+ url,
+ auth=(module.params['login_user'], module.params['login_password']),
+ headers={"content-type": "application/json"},
+ data=json.dumps({
+ "durable": module.params['durable'],
+ "auto_delete": module.params['auto_delete'],
+ "arguments": module.params['arguments']
+ }),
+ verify=module.params['ca_cert'],
+ cert=(module.params['client_cert'], module.params['client_key'])
+ )
+ elif module.params['state'] == 'absent':
+ r = requests.delete(url, auth=(module.params['login_user'], module.params['login_password']),
+ verify=module.params['ca_cert'], cert=(module.params['client_cert'], module.params['client_key']))
+
+ # RabbitMQ 3.6.7 changed this response code from 204 to 201
+ if r.status_code == 204 or r.status_code == 201:
+ result['changed'] = True
+ module.exit_json(**result)
+ else:
+ module.fail_json(
+ msg="Error creating queue",
+ status=r.status_code,
+ details=r.text
+ )
+
+ else:
+ module.exit_json(
+ changed=False,
+ name=module.params['name']
+ )
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/s3_bucket.py b/test/support/integration/plugins/modules/s3_bucket.py
new file mode 100644
index 00000000..f35cf53b
--- /dev/null
+++ b/test/support/integration/plugins/modules/s3_bucket.py
@@ -0,0 +1,740 @@
+#!/usr/bin/python
+#
+# This is a free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This Ansible library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+
+DOCUMENTATION = '''
+---
+module: s3_bucket
+short_description: Manage S3 buckets in AWS, DigitalOcean, Ceph, Walrus, FakeS3 and StorageGRID
+description:
+ - Manage S3 buckets in AWS, DigitalOcean, Ceph, Walrus, FakeS3 and StorageGRID
+version_added: "2.0"
+requirements: [ boto3 ]
+author: "Rob White (@wimnat)"
+options:
+ force:
+ description:
+ - When trying to delete a bucket, delete all keys (including versions and delete markers)
+ in the bucket first (an s3 bucket must be empty for a successful deletion)
+ type: bool
+ default: 'no'
+ name:
+ description:
+ - Name of the s3 bucket
+ required: true
+ type: str
+ policy:
+ description:
+ - The JSON policy as a string.
+ type: json
+ s3_url:
+ description:
+      - S3 URL endpoint for use with DigitalOcean, Ceph, Eucalyptus, fakes3, etc.
+ - Assumes AWS if not specified.
+ - For Walrus, use FQDN of the endpoint without scheme nor path.
+ aliases: [ S3_URL ]
+ type: str
+ ceph:
+ description:
+ - Enable API compatibility with Ceph. It takes into account the S3 API subset working
+ with Ceph in order to provide the same module behaviour where possible.
+ type: bool
+ version_added: "2.2"
+ requester_pays:
+ description:
+ - With Requester Pays buckets, the requester instead of the bucket owner pays the cost
+ of the request and the data download from the bucket.
+ type: bool
+ default: False
+ state:
+ description:
+ - Create or remove the s3 bucket
+ required: false
+ default: present
+ choices: [ 'present', 'absent' ]
+ type: str
+ tags:
+ description:
+      - Tags dict to apply to the bucket
+ type: dict
+ purge_tags:
+ description:
+      - Whether to remove tags that aren't present in the C(tags) parameter
+ type: bool
+ default: True
+ version_added: "2.9"
+ versioning:
+ description:
+ - Whether versioning is enabled or disabled (note that once versioning is enabled, it can only be suspended)
+ type: bool
+ encryption:
+ description:
+ - Describes the default server-side encryption to apply to new objects in the bucket.
+ In order to remove the server-side encryption, the encryption needs to be set to 'none' explicitly.
+ choices: [ 'none', 'AES256', 'aws:kms' ]
+ version_added: "2.9"
+ type: str
+ encryption_key_id:
+    description: KMS master key ID to use for the default encryption. This parameter is allowed only if encryption is aws:kms. If
+                 not specified, it defaults to the AWS-provided KMS key.
+ version_added: "2.9"
+ type: str
+extends_documentation_fragment:
+ - aws
+ - ec2
+notes:
+    - If C(requestPayment), C(policy), C(tagging) or C(versioning)
+      operations/API aren't implemented by the endpoint, the module doesn't fail
+      as long as the related parameters keep their defaults, i.e.
+      I(requester_pays) is C(False) and I(policy), I(tags), and I(versioning) are C(None).
+'''
+
+EXAMPLES = '''
+# Note: These examples do not set authentication details, see the AWS Guide for details.
+
+# Create a simple s3 bucket
+- s3_bucket:
+ name: mys3bucket
+ state: present
+
+# Create a simple s3 bucket on Ceph Rados Gateway
+- s3_bucket:
+ name: mys3bucket
+ s3_url: http://your-ceph-rados-gateway-server.xxx
+ ceph: true
+
+# Remove an s3 bucket and any keys it contains
+- s3_bucket:
+ name: mys3bucket
+ state: absent
+ force: yes
+
+# Create a bucket, add a policy from a file, enable requester pays, enable versioning and tag
+- s3_bucket:
+ name: mys3bucket
+ policy: "{{ lookup('file','policy.json') }}"
+ requester_pays: yes
+ versioning: yes
+ tags:
+ example: tag1
+ another: tag2
+
+# Create a simple DigitalOcean Spaces bucket using their provided regional endpoint
+- s3_bucket:
+ name: mydobucket
+ s3_url: 'https://nyc3.digitaloceanspaces.com'
+
+# Create a bucket with AES256 encryption
+- s3_bucket:
+ name: mys3bucket
+ state: present
+ encryption: "AES256"
+
+# Create a bucket with aws:kms encryption, KMS key
+- s3_bucket:
+ name: mys3bucket
+ state: present
+ encryption: "aws:kms"
+ encryption_key_id: "arn:aws:kms:us-east-1:1234/5678example"
+
+# Create a bucket with aws:kms encryption, default key
+- s3_bucket:
+ name: mys3bucket
+ state: present
+ encryption: "aws:kms"
+'''
+
+import json
+import os
+import time
+
+from ansible.module_utils.six.moves.urllib.parse import urlparse
+from ansible.module_utils.six import string_types
+from ansible.module_utils.basic import to_text
+from ansible.module_utils.aws.core import AnsibleAWSModule, is_boto3_error_code
+from ansible.module_utils.ec2 import compare_policies, ec2_argument_spec, boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list
+from ansible.module_utils.ec2 import get_aws_connection_info, boto3_conn, AWSRetry
+
+try:
+ from botocore.exceptions import BotoCoreError, ClientError, EndpointConnectionError, WaiterError
+except ImportError:
+ pass # handled by AnsibleAWSModule
+
+
+def create_or_update_bucket(s3_client, module, location):
+
+ policy = module.params.get("policy")
+ name = module.params.get("name")
+ requester_pays = module.params.get("requester_pays")
+ tags = module.params.get("tags")
+ purge_tags = module.params.get("purge_tags")
+ versioning = module.params.get("versioning")
+ encryption = module.params.get("encryption")
+ encryption_key_id = module.params.get("encryption_key_id")
+ changed = False
+ result = {}
+
+ try:
+ bucket_is_present = bucket_exists(s3_client, name)
+ except EndpointConnectionError as e:
+ module.fail_json_aws(e, msg="Invalid endpoint provided: %s" % to_text(e))
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to check bucket presence")
+
+ if not bucket_is_present:
+ try:
+ bucket_changed = create_bucket(s3_client, name, location)
+ s3_client.get_waiter('bucket_exists').wait(Bucket=name)
+ changed = changed or bucket_changed
+ except WaiterError as e:
+ module.fail_json_aws(e, msg='An error occurred waiting for the bucket to become available')
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed while creating bucket")
+
+ # Versioning
+ try:
+ versioning_status = get_bucket_versioning(s3_client, name)
+ except BotoCoreError as exp:
+ module.fail_json_aws(exp, msg="Failed to get bucket versioning")
+ except ClientError as exp:
+ if exp.response['Error']['Code'] != 'NotImplemented' or versioning is not None:
+ module.fail_json_aws(exp, msg="Failed to get bucket versioning")
+ else:
+ if versioning is not None:
+ required_versioning = None
+ if versioning and versioning_status.get('Status') != "Enabled":
+ required_versioning = 'Enabled'
+ elif not versioning and versioning_status.get('Status') == "Enabled":
+ required_versioning = 'Suspended'
+
+ if required_versioning:
+ try:
+ put_bucket_versioning(s3_client, name, required_versioning)
+ changed = True
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to update bucket versioning")
+
+ versioning_status = wait_versioning_is_applied(module, s3_client, name, required_versioning)
+
+ # This output format is there to ensure compatibility with previous versions of the module
+ result['versioning'] = {
+ 'Versioning': versioning_status.get('Status', 'Disabled'),
+ 'MfaDelete': versioning_status.get('MFADelete', 'Disabled'),
+ }
+
+ # Requester pays
+ try:
+ requester_pays_status = get_bucket_request_payment(s3_client, name)
+ except BotoCoreError as exp:
+ module.fail_json_aws(exp, msg="Failed to get bucket request payment")
+ except ClientError as exp:
+ if exp.response['Error']['Code'] not in ('NotImplemented', 'XNotImplemented') or requester_pays:
+ module.fail_json_aws(exp, msg="Failed to get bucket request payment")
+ else:
+ if requester_pays:
+ payer = 'Requester' if requester_pays else 'BucketOwner'
+ if requester_pays_status != payer:
+ put_bucket_request_payment(s3_client, name, payer)
+ requester_pays_status = wait_payer_is_applied(module, s3_client, name, payer, should_fail=False)
+ if requester_pays_status is None:
+                    # We have seen quite often that the put request is not taken
+                    # into account on the first try, so we retry once more
+ put_bucket_request_payment(s3_client, name, payer)
+ requester_pays_status = wait_payer_is_applied(module, s3_client, name, payer, should_fail=True)
+ changed = True
+
+ result['requester_pays'] = requester_pays
+
+ # Policy
+ try:
+ current_policy = get_bucket_policy(s3_client, name)
+ except BotoCoreError as exp:
+ module.fail_json_aws(exp, msg="Failed to get bucket policy")
+ except ClientError as exp:
+ if exp.response['Error']['Code'] != 'NotImplemented' or policy is not None:
+ module.fail_json_aws(exp, msg="Failed to get bucket policy")
+ else:
+ if policy is not None:
+ if isinstance(policy, string_types):
+ policy = json.loads(policy)
+
+ if not policy and current_policy:
+ try:
+ delete_bucket_policy(s3_client, name)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to delete bucket policy")
+ current_policy = wait_policy_is_applied(module, s3_client, name, policy)
+ changed = True
+ elif compare_policies(current_policy, policy):
+ try:
+ put_bucket_policy(s3_client, name, policy)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to update bucket policy")
+ current_policy = wait_policy_is_applied(module, s3_client, name, policy, should_fail=False)
+ if current_policy is None:
+                    # As with request payment, the put request is quite often not
+                    # taken into account on the first try, so we retry once more
+ put_bucket_policy(s3_client, name, policy)
+ current_policy = wait_policy_is_applied(module, s3_client, name, policy, should_fail=True)
+ changed = True
+
+ result['policy'] = current_policy
+
+ # Tags
+ try:
+ current_tags_dict = get_current_bucket_tags_dict(s3_client, name)
+ except BotoCoreError as exp:
+ module.fail_json_aws(exp, msg="Failed to get bucket tags")
+ except ClientError as exp:
+ if exp.response['Error']['Code'] not in ('NotImplemented', 'XNotImplemented') or tags is not None:
+ module.fail_json_aws(exp, msg="Failed to get bucket tags")
+ else:
+ if tags is not None:
+ # Tags are always returned as text
+ tags = dict((to_text(k), to_text(v)) for k, v in tags.items())
+ if not purge_tags:
+ # Ensure existing tags that aren't updated by desired tags remain
+ current_copy = current_tags_dict.copy()
+ current_copy.update(tags)
+ tags = current_copy
+ if current_tags_dict != tags:
+ if tags:
+ try:
+ put_bucket_tagging(s3_client, name, tags)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to update bucket tags")
+ else:
+ if purge_tags:
+ try:
+ delete_bucket_tagging(s3_client, name)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to delete bucket tags")
+ current_tags_dict = wait_tags_are_applied(module, s3_client, name, tags)
+ changed = True
+
+ result['tags'] = current_tags_dict
+
+ # Encryption
+ if hasattr(s3_client, "get_bucket_encryption"):
+ try:
+ current_encryption = get_bucket_encryption(s3_client, name)
+ except (ClientError, BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to get bucket encryption")
+ elif encryption is not None:
+ module.fail_json(msg="Using bucket encryption requires botocore version >= 1.7.41")
+
+ if encryption is not None:
+ current_encryption_algorithm = current_encryption.get('SSEAlgorithm') if current_encryption else None
+ current_encryption_key = current_encryption.get('KMSMasterKeyID') if current_encryption else None
+ if encryption == 'none' and current_encryption_algorithm is not None:
+ try:
+ delete_bucket_encryption(s3_client, name)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to delete bucket encryption")
+ current_encryption = wait_encryption_is_applied(module, s3_client, name, None)
+ changed = True
+        elif encryption != 'none' and ((encryption != current_encryption_algorithm) or
+                                       (encryption == 'aws:kms' and current_encryption_key != encryption_key_id)):
+ expected_encryption = {'SSEAlgorithm': encryption}
+ if encryption == 'aws:kms' and encryption_key_id is not None:
+ expected_encryption.update({'KMSMasterKeyID': encryption_key_id})
+ try:
+ put_bucket_encryption(s3_client, name, expected_encryption)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to set bucket encryption")
+ current_encryption = wait_encryption_is_applied(module, s3_client, name, expected_encryption)
+ changed = True
+
+ result['encryption'] = current_encryption
+
+ module.exit_json(changed=changed, name=name, **result)
+
+
+def bucket_exists(s3_client, bucket_name):
+    # head_bucket appeared to be really inconsistent, so we use list_buckets instead,
+    # and loop over all the buckets, even if we know it's less performant :(
+    # (list_buckets takes no bucket filter, so every bucket comes back)
+    all_buckets = s3_client.list_buckets()['Buckets']
+ return any(bucket['Name'] == bucket_name for bucket in all_buckets)
+
+
+@AWSRetry.exponential_backoff(max_delay=120)
+def create_bucket(s3_client, bucket_name, location):
+ try:
+ configuration = {}
+ if location not in ('us-east-1', None):
+ configuration['LocationConstraint'] = location
+ if len(configuration) > 0:
+ s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration=configuration)
+ else:
+ s3_client.create_bucket(Bucket=bucket_name)
+ return True
+ except ClientError as e:
+ error_code = e.response['Error']['Code']
+ if error_code == 'BucketAlreadyOwnedByYou':
+            # We should never get here since we check bucket presence before calling the create_or_update_bucket
+            # method. However, the AWS API sometimes fails to report bucket presence, so we catch this exception
+ return False
+ else:
+ raise e
+
+
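+# 'NoSuchBucket' is added to the retried error codes below because a freshly
+# created bucket may not yet be visible to every S3 endpoint.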
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def put_bucket_tagging(s3_client, bucket_name, tags):
+ s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={'TagSet': ansible_dict_to_boto3_tag_list(tags)})
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def put_bucket_policy(s3_client, bucket_name, policy):
+ s3_client.put_bucket_policy(Bucket=bucket_name, Policy=json.dumps(policy))
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def delete_bucket_policy(s3_client, bucket_name):
+ s3_client.delete_bucket_policy(Bucket=bucket_name)
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def get_bucket_policy(s3_client, bucket_name):
+ try:
+ current_policy = json.loads(s3_client.get_bucket_policy(Bucket=bucket_name).get('Policy'))
+ except ClientError as e:
+ if e.response['Error']['Code'] == 'NoSuchBucketPolicy':
+ current_policy = None
+ else:
+ raise e
+ return current_policy
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def put_bucket_request_payment(s3_client, bucket_name, payer):
+ s3_client.put_bucket_request_payment(Bucket=bucket_name, RequestPaymentConfiguration={'Payer': payer})
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def get_bucket_request_payment(s3_client, bucket_name):
+ return s3_client.get_bucket_request_payment(Bucket=bucket_name).get('Payer')
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def get_bucket_versioning(s3_client, bucket_name):
+ return s3_client.get_bucket_versioning(Bucket=bucket_name)
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def put_bucket_versioning(s3_client, bucket_name, required_versioning):
+ s3_client.put_bucket_versioning(Bucket=bucket_name, VersioningConfiguration={'Status': required_versioning})
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def get_bucket_encryption(s3_client, bucket_name):
+ try:
+ result = s3_client.get_bucket_encryption(Bucket=bucket_name)
+ return result.get('ServerSideEncryptionConfiguration', {}).get('Rules', [])[0].get('ApplyServerSideEncryptionByDefault')
+ except ClientError as e:
+ if e.response['Error']['Code'] == 'ServerSideEncryptionConfigurationNotFoundError':
+ return None
+ else:
+ raise e
+ except (IndexError, KeyError):
+ return None
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def put_bucket_encryption(s3_client, bucket_name, encryption):
+ server_side_encryption_configuration = {'Rules': [{'ApplyServerSideEncryptionByDefault': encryption}]}
+ s3_client.put_bucket_encryption(Bucket=bucket_name, ServerSideEncryptionConfiguration=server_side_encryption_configuration)
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def delete_bucket_tagging(s3_client, bucket_name):
+ s3_client.delete_bucket_tagging(Bucket=bucket_name)
+
+
+@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket'])
+def delete_bucket_encryption(s3_client, bucket_name):
+ s3_client.delete_bucket_encryption(Bucket=bucket_name)
+
+
+@AWSRetry.exponential_backoff(max_delay=120)
+def delete_bucket(s3_client, bucket_name):
+ try:
+ s3_client.delete_bucket(Bucket=bucket_name)
+ except ClientError as e:
+ if e.response['Error']['Code'] == 'NoSuchBucket':
+            # This means the bucket was probably in a deleting state when we checked its existence
+ # We just ignore the error
+ pass
+ else:
+ raise e
+
+
+def wait_policy_is_applied(module, s3_client, bucket_name, expected_policy, should_fail=True):
+ for dummy in range(0, 12):
+ try:
+ current_policy = get_bucket_policy(s3_client, bucket_name)
+ except (ClientError, BotoCoreError) as e:
+ module.fail_json_aws(e, msg="Failed to get bucket policy")
+
+ if compare_policies(current_policy, expected_policy):
+ time.sleep(5)
+ else:
+ return current_policy
+ if should_fail:
+ module.fail_json(msg="Bucket policy failed to apply in the expected time")
+ else:
+ return None
+
+
+def wait_payer_is_applied(module, s3_client, bucket_name, expected_payer, should_fail=True):
+ for dummy in range(0, 12):
+ try:
+ requester_pays_status = get_bucket_request_payment(s3_client, bucket_name)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to get bucket request payment")
+ if requester_pays_status != expected_payer:
+ time.sleep(5)
+ else:
+ return requester_pays_status
+ if should_fail:
+ module.fail_json(msg="Bucket request payment failed to apply in the expected time")
+ else:
+ return None
+
+
+def wait_encryption_is_applied(module, s3_client, bucket_name, expected_encryption):
+ for dummy in range(0, 12):
+ try:
+ encryption = get_bucket_encryption(s3_client, bucket_name)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to get updated encryption for bucket")
+ if encryption != expected_encryption:
+ time.sleep(5)
+ else:
+ return encryption
+ module.fail_json(msg="Bucket encryption failed to apply in the expected time")
+
+
+def wait_versioning_is_applied(module, s3_client, bucket_name, required_versioning):
+ for dummy in range(0, 24):
+ try:
+ versioning_status = get_bucket_versioning(s3_client, bucket_name)
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to get updated versioning for bucket")
+ if versioning_status.get('Status') != required_versioning:
+ time.sleep(8)
+ else:
+ return versioning_status
+ module.fail_json(msg="Bucket versioning failed to apply in the expected time")
+
+
+def wait_tags_are_applied(module, s3_client, bucket_name, expected_tags_dict):
+ for dummy in range(0, 12):
+ try:
+ current_tags_dict = get_current_bucket_tags_dict(s3_client, bucket_name)
+ except (ClientError, BotoCoreError) as e:
+            module.fail_json_aws(e, msg="Failed to get bucket tags")
+ if current_tags_dict != expected_tags_dict:
+ time.sleep(5)
+ else:
+ return current_tags_dict
+ module.fail_json(msg="Bucket tags failed to apply in the expected time")
+
+
+def get_current_bucket_tags_dict(s3_client, bucket_name):
+ try:
+ current_tags = s3_client.get_bucket_tagging(Bucket=bucket_name).get('TagSet')
+ except ClientError as e:
+ if e.response['Error']['Code'] == 'NoSuchTagSet':
+ return {}
+ raise e
+
+ return boto3_tag_list_to_ansible_dict(current_tags)
+
+
+def paginated_list(s3_client, **pagination_params):
+ pg = s3_client.get_paginator('list_objects_v2')
+ for page in pg.paginate(**pagination_params):
+ yield [data['Key'] for data in page.get('Contents', [])]
+
+
+def paginated_versions_list(s3_client, **pagination_params):
+ try:
+ pg = s3_client.get_paginator('list_object_versions')
+ for page in pg.paginate(**pagination_params):
+ # We have to merge the Versions and DeleteMarker lists here, as DeleteMarkers can still prevent a bucket deletion
+ yield [(data['Key'], data['VersionId']) for data in (page.get('Versions', []) + page.get('DeleteMarkers', []))]
+ except is_boto3_error_code('NoSuchBucket'):
+ yield []
+
+
+def destroy_bucket(s3_client, module):
+
+ force = module.params.get("force")
+ name = module.params.get("name")
+ try:
+ bucket_is_present = bucket_exists(s3_client, name)
+ except EndpointConnectionError as e:
+ module.fail_json_aws(e, msg="Invalid endpoint provided: %s" % to_text(e))
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to check bucket presence")
+
+ if not bucket_is_present:
+ module.exit_json(changed=False)
+
+ if force:
+ # if there are contents then we need to delete them (including versions) before we can delete the bucket
+ try:
+ for key_version_pairs in paginated_versions_list(s3_client, Bucket=name):
+ formatted_keys = [{'Key': key, 'VersionId': version} for key, version in key_version_pairs]
+ for fk in formatted_keys:
+ # remove VersionId from cases where they are `None` so that
+ # unversioned objects are deleted using `DeleteObject`
+ # rather than `DeleteObjectVersion`, improving backwards
+ # compatibility with older IAM policies.
+ if not fk.get('VersionId'):
+ fk.pop('VersionId')
+
+ if formatted_keys:
+ resp = s3_client.delete_objects(Bucket=name, Delete={'Objects': formatted_keys})
+ if resp.get('Errors'):
+ module.fail_json(
+ msg='Could not empty bucket before deleting. Could not delete objects: {0}'.format(
+ ', '.join([k['Key'] for k in resp['Errors']])
+ ),
+ errors=resp['Errors'], response=resp
+ )
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed while deleting bucket")
+
+ try:
+ delete_bucket(s3_client, name)
+ s3_client.get_waiter('bucket_not_exists').wait(Bucket=name, WaiterConfig=dict(Delay=5, MaxAttempts=60))
+ except WaiterError as e:
+ module.fail_json_aws(e, msg='An error occurred waiting for the bucket to be deleted.')
+ except (BotoCoreError, ClientError) as e:
+ module.fail_json_aws(e, msg="Failed to delete bucket")
+
+ module.exit_json(changed=True)
+
+
+def is_fakes3(s3_url):
+ """ Return True if s3_url has scheme fakes3:// """
+ if s3_url is not None:
+ return urlparse(s3_url).scheme in ('fakes3', 'fakes3s')
+ else:
+ return False
+
+
+def get_s3_client(module, aws_connect_kwargs, location, ceph, s3_url):
+ if s3_url and ceph: # TODO - test this
+ ceph = urlparse(s3_url)
+ params = dict(module=module, conn_type='client', resource='s3', use_ssl=ceph.scheme == 'https', region=location, endpoint=s3_url, **aws_connect_kwargs)
+ elif is_fakes3(s3_url):
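+ # e.g. (assumed values) 'fakes3://localhost:4567' -> endpoint 'http://localhost:4567',
+ # 'fakes3s://localhost' -> endpoint 'https://localhost:443'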
+ fakes3 = urlparse(s3_url)
+ port = fakes3.port
+ if fakes3.scheme == 'fakes3s':
+ protocol = "https"
+ if port is None:
+ port = 443
+ else:
+ protocol = "http"
+ if port is None:
+ port = 80
+ params = dict(module=module, conn_type='client', resource='s3', region=location,
+ endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)),
+ use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs)
+ else:
+ params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=s3_url, **aws_connect_kwargs)
+ return boto3_conn(**params)
+
+
+def main():
+
+ argument_spec = ec2_argument_spec()
+ argument_spec.update(
+ dict(
+ force=dict(default=False, type='bool'),
+ policy=dict(type='json'),
+ name=dict(required=True),
+ requester_pays=dict(default=False, type='bool'),
+ s3_url=dict(aliases=['S3_URL']),
+ state=dict(default='present', choices=['present', 'absent']),
+ tags=dict(type='dict'),
+ purge_tags=dict(type='bool', default=True),
+ versioning=dict(type='bool'),
+ ceph=dict(default=False, type='bool'),
+ encryption=dict(choices=['none', 'AES256', 'aws:kms']),
+ encryption_key_id=dict()
+ )
+ )
+
+ module = AnsibleAWSModule(
+ argument_spec=argument_spec,
+ )
+
+ region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)
+
+ if region in ('us-east-1', '', None):
+ # default to US Standard region
+ location = 'us-east-1'
+ else:
+ # Boto uses symbolic names for locations but region strings will
+ # actually work fine for everything except us-east-1 (US Standard)
+ location = region
+
+ s3_url = module.params.get('s3_url')
+ ceph = module.params.get('ceph')
+
+ # allow eucarc environment variables to be used if ansible vars aren't set
+ if not s3_url and 'S3_URL' in os.environ:
+ s3_url = os.environ['S3_URL']
+
+ if ceph and not s3_url:
+ module.fail_json(msg='ceph flavour requires s3_url')
+
+ # Look at s3_url and tweak connection settings
+ # if connecting to Ceph RGW, Walrus or fakes3
+ if s3_url:
+ for key in ['validate_certs', 'security_token', 'profile_name']:
+ aws_connect_kwargs.pop(key, None)
+ s3_client = get_s3_client(module, aws_connect_kwargs, location, ceph, s3_url)
+
+ if s3_client is None: # this should never happen
+ module.fail_json(msg='Unknown error, failed to create s3 connection, no information from boto.')
+
+ state = module.params.get("state")
+ encryption = module.params.get("encryption")
+ encryption_key_id = module.params.get("encryption_key_id")
+
+ # Parameter validation
+ if encryption_key_id is not None and encryption is None:
+ module.fail_json(msg="You must specify encryption parameter along with encryption_key_id.")
+ elif encryption_key_id is not None and encryption != 'aws:kms':
+ module.fail_json(msg="Only 'aws:kms' is a valid option for encryption parameter when you specify encryption_key_id.")
+
+ if state == 'present':
+ create_or_update_bucket(s3_client, module, location)
+ elif state == 'absent':
+ destroy_bucket(s3_client, module)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/sefcontext.py b/test/support/integration/plugins/modules/sefcontext.py
new file mode 100644
index 00000000..33e3fd2e
--- /dev/null
+++ b/test/support/integration/plugins/modules/sefcontext.py
@@ -0,0 +1,298 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2016, Dag Wieers (@dagwieers) <dag@wieers.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: sefcontext
+short_description: Manages SELinux file context mapping definitions
+description:
+- Manages SELinux file context mapping definitions.
+- Similar to the C(semanage fcontext) command.
+version_added: '2.2'
+options:
+ target:
+ description:
+ - Target path (expression).
+ type: str
+ required: yes
+ aliases: [ path ]
+ ftype:
+ description:
+ - The file type that should have SELinux contexts applied.
+ - "The following file type options are available:"
+ - C(a) for all files,
+ - C(b) for block devices,
+ - C(c) for character devices,
+ - C(d) for directories,
+ - C(f) for regular files,
+ - C(l) for symbolic links,
+ - C(p) for named pipes,
+ - C(s) for socket files.
+ type: str
+ choices: [ a, b, c, d, f, l, p, s ]
+ default: a
+ setype:
+ description:
+ - SELinux type for the specified target.
+ type: str
+ required: yes
+ seuser:
+ description:
+ - SELinux user for the specified target.
+ type: str
+ selevel:
+ description:
+ - SELinux range for the specified target.
+ type: str
+ aliases: [ serange ]
+ state:
+ description:
+ - Whether the SELinux file context must be C(absent) or C(present).
+ type: str
+ choices: [ absent, present ]
+ default: present
+ reload:
+ description:
+ - Reload SELinux policy after commit.
+ - Note that this does not apply SELinux file contexts to existing files.
+ type: bool
+ default: yes
+ ignore_selinux_state:
+ description:
+ - Useful for scenarios (such as a chrooted environment) in which you cannot get the real SELinux state.
+ type: bool
+ default: no
+ version_added: '2.8'
+notes:
+- The changes are persistent across reboots.
+- The M(sefcontext) module does not apply the new SELinux context(s) to
+ existing files, so it is advisable to define the SELinux file contexts
+ before creating files, or to run C(restorecon) manually for existing
+ files that require the new SELinux file contexts.
+- Not applying SELinux fcontexts to existing files is a deliberate
+ decision, as it would be unclear what the reported changes would entail,
+ and there is no guarantee that applying the fcontexts would not pick up
+ unrelated prior changes.
+requirements:
+- libselinux-python
+- policycoreutils-python
+author:
+- Dag Wieers (@dagwieers)
+'''
+
+EXAMPLES = r'''
+- name: Allow apache to modify files in /srv/git_repos
+ sefcontext:
+ target: '/srv/git_repos(/.*)?'
+ setype: httpd_git_rw_content_t
+ state: present
+
+- name: Apply new SELinux file context to filesystem
+ command: restorecon -irv /srv/git_repos
+'''
+
+RETURN = r'''
+# Default return values
+'''
+
+import traceback
+
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils._text import to_native
+
+SELINUX_IMP_ERR = None
+try:
+ import selinux
+ HAVE_SELINUX = True
+except ImportError:
+ SELINUX_IMP_ERR = traceback.format_exc()
+ HAVE_SELINUX = False
+
+SEOBJECT_IMP_ERR = None
+try:
+ import seobject
+ HAVE_SEOBJECT = True
+except ImportError:
+ SEOBJECT_IMP_ERR = traceback.format_exc()
+ HAVE_SEOBJECT = False
+
+# Add missing entries (backward compatible)
+if HAVE_SEOBJECT:
+ seobject.file_types.update(
+ a=seobject.SEMANAGE_FCONTEXT_ALL,
+ b=seobject.SEMANAGE_FCONTEXT_BLOCK,
+ c=seobject.SEMANAGE_FCONTEXT_CHAR,
+ d=seobject.SEMANAGE_FCONTEXT_DIR,
+ f=seobject.SEMANAGE_FCONTEXT_REG,
+ l=seobject.SEMANAGE_FCONTEXT_LINK,
+ p=seobject.SEMANAGE_FCONTEXT_PIPE,
+ s=seobject.SEMANAGE_FCONTEXT_SOCK,
+ )
+
+# Make backward compatible
+option_to_file_type_str = dict(
+ a='all files',
+ b='block device',
+ c='character device',
+ d='directory',
+ f='regular file',
+ l='symbolic link',
+ p='named pipe',
+ s='socket',
+)
+
+
+def get_runtime_status(ignore_selinux_state=False):
+ return True if ignore_selinux_state is True else selinux.is_selinux_enabled()
+
+
+def semanage_fcontext_exists(sefcontext, target, ftype):
+ ''' Get the SELinux file context mapping definition from policy. Return None if it does not exist. '''
+
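+ # Illustrative shape of the records returned by fcontextRecords.get_all(),
+ # values assumed; keys are (target, file type string) tuples and values
+ # unpack as (seuser, serole, setype, serange), e.g.:
+ #   {('/srv/git_repos(/.*)?', 'all files'):
+ #    ('system_u', 'object_r', 'httpd_git_rw_content_t', 's0')}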
+ # Beware that record keys use the string representation of the file_type
+ # (e.g. 'all files'), not the single-letter option.
+ record = (target, option_to_file_type_str[ftype])
+ records = sefcontext.get_all()
+ try:
+ return records[record]
+ except KeyError:
+ return None
+
+
+def semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser, sestore=''):
+ ''' Add or modify SELinux file context mapping definition to the policy. '''
+
+ changed = False
+ prepared_diff = ''
+
+ try:
+ sefcontext = seobject.fcontextRecords(sestore)
+ sefcontext.set_reload(do_reload)
+ exists = semanage_fcontext_exists(sefcontext, target, ftype)
+ if exists:
+ # Modify existing entry
+ orig_seuser, orig_serole, orig_setype, orig_serange = exists
+
+ if seuser is None:
+ seuser = orig_seuser
+ if serange is None:
+ serange = orig_serange
+
+ if setype != orig_setype or seuser != orig_seuser or serange != orig_serange:
+ if not module.check_mode:
+ sefcontext.modify(target, setype, ftype, serange, seuser)
+ changed = True
+
+ if module._diff:
+ prepared_diff += '# Change to semanage file context mappings\n'
+ prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, orig_seuser, orig_serole, orig_setype, orig_serange)
+ prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, orig_serole, setype, serange)
+ else:
+ # Add missing entry
+ if seuser is None:
+ seuser = 'system_u'
+ if serange is None:
+ serange = 's0'
+
+ if not module.check_mode:
+ sefcontext.add(target, setype, ftype, serange, seuser)
+ changed = True
+
+ if module._diff:
+ prepared_diff += '# Addition to semanage file context mappings\n'
+ prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, 'object_r', setype, serange)
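+ # An added mapping renders like this in the prepared diff (illustrative):
+ # +/srv/git_repos(/.*)? a system_u:object_r:httpd_git_rw_content_t:s0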
+
+ except Exception as e:
+ module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)))
+
+ if module._diff and prepared_diff:
+ result['diff'] = dict(prepared=prepared_diff)
+
+ module.exit_json(changed=changed, seuser=seuser, serange=serange, **result)
+
+
+def semanage_fcontext_delete(module, result, target, ftype, do_reload, sestore=''):
+ ''' Delete SELinux file context mapping definition from the policy. '''
+
+ changed = False
+ prepared_diff = ''
+
+ try:
+ sefcontext = seobject.fcontextRecords(sestore)
+ sefcontext.set_reload(do_reload)
+ exists = semanage_fcontext_exists(sefcontext, target, ftype)
+ if exists:
+ # Remove existing entry
+ orig_seuser, orig_serole, orig_setype, orig_serange = exists
+
+ if not module.check_mode:
+ sefcontext.delete(target, ftype)
+ changed = True
+
+ if module._diff:
+ prepared_diff += '# Deletion from semanage file context mappings\n'
+ prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, exists[0], exists[1], exists[2], exists[3])
+
+ except Exception as e:
+ module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)))
+
+ if module._diff and prepared_diff:
+ result['diff'] = dict(prepared=prepared_diff)
+
+ module.exit_json(changed=changed, **result)
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ ignore_selinux_state=dict(type='bool', default=False),
+ target=dict(type='str', required=True, aliases=['path']),
+ ftype=dict(type='str', default='a', choices=list(option_to_file_type_str.keys())),
+ setype=dict(type='str', required=True),
+ seuser=dict(type='str'),
+ selevel=dict(type='str', aliases=['serange']),
+ state=dict(type='str', default='present', choices=['absent', 'present']),
+ reload=dict(type='bool', default=True),
+ ),
+ supports_check_mode=True,
+ )
+ if not HAVE_SELINUX:
+ module.fail_json(msg=missing_required_lib("libselinux-python"), exception=SELINUX_IMP_ERR)
+
+ if not HAVE_SEOBJECT:
+ module.fail_json(msg=missing_required_lib("policycoreutils-python"), exception=SEOBJECT_IMP_ERR)
+
+ ignore_selinux_state = module.params['ignore_selinux_state']
+
+ if not get_runtime_status(ignore_selinux_state):
+ module.fail_json(msg="SELinux is disabled on this host.")
+
+ target = module.params['target']
+ ftype = module.params['ftype']
+ setype = module.params['setype']
+ seuser = module.params['seuser']
+ serange = module.params['selevel']
+ state = module.params['state']
+ do_reload = module.params['reload']
+
+ result = dict(target=target, ftype=ftype, setype=setype, state=state)
+
+ if state == 'present':
+ semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser)
+ elif state == 'absent':
+ semanage_fcontext_delete(module, result, target, ftype, do_reload)
+ else:
+ module.fail_json(msg='Invalid value of argument "state": {0}'.format(state))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/selogin.py b/test/support/integration/plugins/modules/selogin.py
new file mode 100644
index 00000000..6429ef36
--- /dev/null
+++ b/test/support/integration/plugins/modules/selogin.py
@@ -0,0 +1,260 @@
+#!/usr/bin/python
+
+# (c) 2017, Petr Lautrbach <plautrba@redhat.com>
+# Based on seport.py module (c) 2014, Dan Keder <dan.keder@gmail.com>
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = '''
+---
+module: selogin
+short_description: Manages Linux user to SELinux user mapping
+description:
+ - Manages Linux user to SELinux user mapping.
+version_added: "2.8"
+options:
+ login:
+ description:
+ - A Linux user, or a Linux group if it begins with C(%).
+ required: true
+ seuser:
+ description:
+ - SELinux user name
+ required: true
+ selevel:
+ aliases: [ serange ]
+ description:
+ - MLS/MCS security range (MLS/MCS systems only). The SELinux range for the login mapping; defaults to the range of the SELinux user record.
+ default: s0
+ state:
+ description:
+ - Whether the login mapping should be present or absent.
+ default: present
+ choices: [ 'present', 'absent' ]
+ reload:
+ description:
+ - Reload SELinux policy after commit.
+ default: yes
+ ignore_selinux_state:
+ description:
+ - Run independently of the SELinux runtime state.
+ type: bool
+ default: false
+notes:
+ - The changes are persistent across reboots.
+ - Not tested on any Debian-based system.
+requirements: [ 'libselinux', 'policycoreutils' ]
+author:
+- Dan Keder (@dankeder)
+- Petr Lautrbach (@bachradsusi)
+- James Cassell (@jamescassell)
+'''
+
+EXAMPLES = '''
+# Modify the default user on the system to the guest_u user
+- selogin:
+ login: __default__
+ seuser: guest_u
+ state: present
+
+# Assign gijoe user on an MLS machine a range and to the staff_u user
+- selogin:
+ login: gijoe
+ seuser: staff_u
+ serange: SystemLow-Secret
+ state: present
+
+# Assign all users in the engineering group to the staff_u user
+- selogin:
+ login: '%engineering'
+ seuser: staff_u
+ state: present
+'''
+
+RETURN = r'''
+# Default return values
+'''
+
+
+import traceback
+
+SELINUX_IMP_ERR = None
+try:
+ import selinux
+ HAVE_SELINUX = True
+except ImportError:
+ SELINUX_IMP_ERR = traceback.format_exc()
+ HAVE_SELINUX = False
+
+SEOBJECT_IMP_ERR = None
+try:
+ import seobject
+ HAVE_SEOBJECT = True
+except ImportError:
+ SEOBJECT_IMP_ERR = traceback.format_exc()
+ HAVE_SEOBJECT = False
+
+
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+from ansible.module_utils._text import to_native
+
+
+def semanage_login_add(module, login, seuser, do_reload, serange='s0', sestore=''):
+ """ Add linux user to SELinux user mapping
+
+ :type module: AnsibleModule
+ :param module: Ansible module
+
+ :type login: str
+ :param login: a Linux User or a Linux group if it begins with %
+
+ :type seuser: str
+ :param seuser: An SELinux user ('__default__', 'unconfined_u', 'staff_u', ...), see 'semanage login -l'
+
+ :type serange: str
+ :param serange: SELinux MLS/MCS range (defaults to 's0')
+
+ :type do_reload: bool
+ :param do_reload: Whether to reload SELinux policy after commit
+
+ :type sestore: str
+ :param sestore: SELinux store
+
+ :rtype: bool
+ :return: True if the policy was changed, otherwise False
+ """
+ try:
+ selogin = seobject.loginRecords(sestore)
+ selogin.set_reload(do_reload)
+ change = False
+ all_logins = selogin.get_all()
+ # module.fail_json(msg="%s: %s %s" % (all_logins, login, sestore))
+ # for local_login in all_logins:
+ if login not in all_logins:
+ change = True
+ if not module.check_mode:
+ selogin.add(login, seuser, serange)
+ else:
+ if all_logins[login][0] != seuser or all_logins[login][1] != serange:
+ change = True
+ if not module.check_mode:
+ selogin.modify(login, seuser, serange)
+
+ except (ValueError, KeyError, OSError, RuntimeError) as e:
+ module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)), exception=traceback.format_exc())
+
+ return change
+
+
+def semanage_login_del(module, login, seuser, do_reload, sestore=''):
+ """ Delete linux user to SELinux user mapping
+
+ :type module: AnsibleModule
+ :param module: Ansible module
+
+ :type login: str
+ :param login: a Linux User or a Linux group if it begins with %
+
+ :type seuser: str
+ :param seuser: An SELinux user ('__default__', 'unconfined_u', 'staff_u', ...), see 'semanage login -l'
+
+ :type do_reload: bool
+ :param do_reload: Whether to reload SELinux policy after commit
+
+ :type sestore: str
+ :param sestore: SELinux store
+
+ :rtype: bool
+ :return: True if the policy was changed, otherwise False
+ """
+ try:
+ selogin = seobject.loginRecords(sestore)
+ selogin.set_reload(do_reload)
+ change = False
+ all_logins = selogin.get_all()
+ # module.fail_json(msg="%s: %s %s" % (all_logins, login, sestore))
+ if login in all_logins:
+ change = True
+ if not module.check_mode:
+ selogin.delete(login)
+
+ except (ValueError, KeyError, OSError, RuntimeError) as e:
+ module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)), exception=traceback.format_exc())
+
+ return change
+
+
+def get_runtime_status(ignore_selinux_state=False):
+ return True if ignore_selinux_state is True else selinux.is_selinux_enabled()
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ ignore_selinux_state=dict(type='bool', default=False),
+ login=dict(type='str', required=True),
+ seuser=dict(type='str'),
+ selevel=dict(type='str', aliases=['serange'], default='s0'),
+ state=dict(type='str', default='present', choices=['absent', 'present']),
+ reload=dict(type='bool', default=True),
+ ),
+ required_if=[
+ ["state", "present", ["seuser"]]
+ ],
+ supports_check_mode=True
+ )
+ if not HAVE_SELINUX:
+ module.fail_json(msg=missing_required_lib("libselinux"), exception=SELINUX_IMP_ERR)
+
+ if not HAVE_SEOBJECT:
+ module.fail_json(msg=missing_required_lib("seobject from policycoreutils"), exception=SEOBJECT_IMP_ERR)
+
+ ignore_selinux_state = module.params['ignore_selinux_state']
+
+ if not get_runtime_status(ignore_selinux_state):
+ module.fail_json(msg="SELinux is disabled on this host.")
+
+ login = module.params['login']
+ seuser = module.params['seuser']
+ serange = module.params['selevel']
+ state = module.params['state']
+ do_reload = module.params['reload']
+
+ result = {
+ 'login': login,
+ 'seuser': seuser,
+ 'serange': serange,
+ 'state': state,
+ }
+
+ if state == 'present':
+ result['changed'] = semanage_login_add(module, login, seuser, do_reload, serange)
+ elif state == 'absent':
+ result['changed'] = semanage_login_del(module, login, seuser, do_reload)
+ else:
+ module.fail_json(msg='Invalid value of argument "state": {0}'.format(state))
+
+ module.exit_json(**result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/synchronize.py b/test/support/integration/plugins/modules/synchronize.py
new file mode 100644
index 00000000..e4c520b7
--- /dev/null
+++ b/test/support/integration/plugins/modules/synchronize.py
@@ -0,0 +1,618 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2012-2013, Timothy Appnel <tim@appnel.com>
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: synchronize
+version_added: "1.4"
+short_description: A wrapper around rsync to make common tasks in your playbooks quick and easy
+description:
+ - C(synchronize) is a wrapper around rsync to make common tasks in your playbooks quick and easy.
+ - It is run and originates on the local host where Ansible is being run.
+ - Of course, you could just use the C(command) action to call rsync yourself, but you also have to add a fair number of
+ boilerplate options and host facts.
+ - This module is not intended to provide access to the full power of rsync, but does make the most common
+ invocations easier to implement. You `still` may need to call rsync directly via C(command) or C(shell) depending on your use case.
+options:
+ src:
+ description:
+ - Path on the source host that will be synchronized to the destination.
+ - The path can be absolute or relative.
+ type: str
+ required: true
+ dest:
+ description:
+ - Path on the destination host that will be synchronized from the source.
+ - The path can be absolute or relative.
+ type: str
+ required: true
+ dest_port:
+ description:
+ - Port number for ssh on the destination host.
+ - Prior to Ansible 2.0, the ansible_ssh_port inventory var took precedence over this value.
+ - This parameter defaults to the value of C(ansible_ssh_port) or C(ansible_port),
+ the C(remote_port) config setting or the value from ssh client configuration
+ if none of the former have been set.
+ type: int
+ version_added: "1.5"
+ mode:
+ description:
+ - Specify the direction of the synchronization.
+ - In push mode the localhost or delegate is the source.
+ - In pull mode the remote host in context is the source.
+ type: str
+ choices: [ pull, push ]
+ default: push
+ archive:
+ description:
+ - Mirrors the rsync archive flag, enables recursive, links, perms, times, owner, group flags and -D.
+ type: bool
+ default: yes
+ checksum:
+ description:
+ - Skip based on checksum, rather than mod-time & size. Note that the "archive" option is still enabled by default; the "checksum" option will
+ not disable it.
+ type: bool
+ default: no
+ version_added: "1.6"
+ compress:
+ description:
+ - Compress file data during the transfer.
+ - In most cases, leave this enabled unless it causes problems.
+ type: bool
+ default: yes
+ version_added: "1.7"
+ existing_only:
+ description:
+ - Skip creating new files on receiver.
+ type: bool
+ default: no
+ version_added: "1.5"
+ delete:
+ description:
+ - Delete files in C(dest) that don't exist (after transfer, not before) in the C(src) path.
+ - This option requires C(recursive=yes).
+ - This option ignores excluded files and behaves like the rsync opt --delete-excluded.
+ type: bool
+ default: no
+ dirs:
+ description:
+ - Transfer directories without recursing.
+ type: bool
+ default: no
+ recursive:
+ description:
+ - Recurse into directories.
+ - This parameter defaults to the value of the archive option.
+ type: bool
+ links:
+ description:
+ - Copy symlinks as symlinks.
+ - This parameter defaults to the value of the archive option.
+ type: bool
+ copy_links:
+ description:
+ - Copy symlinks as the item that they point to (the referent) is copied, rather than the symlink.
+ type: bool
+ default: no
+ perms:
+ description:
+ - Preserve permissions.
+ - This parameter defaults to the value of the archive option.
+ type: bool
+ times:
+ description:
+ - Preserve modification times.
+ - This parameter defaults to the value of the archive option.
+ type: bool
+ owner:
+ description:
+ - Preserve owner (super user only).
+ - This parameter defaults to the value of the archive option.
+ type: bool
+ group:
+ description:
+ - Preserve group.
+ - This parameter defaults to the value of the archive option.
+ type: bool
+ rsync_path:
+ description:
+ - Specify the rsync command to run on the remote host. See C(--rsync-path) on the rsync man page.
+ - To specify the rsync command to run on the local host, set the task variable C(ansible_rsync_path).
+ type: str
+ rsync_timeout:
+ description:
+ - Specify a C(--timeout) for the rsync command in seconds.
+ type: int
+ default: 0
+ set_remote_user:
+ description:
+ - Prepend C(user@) to the remote paths.
+ - If you have a custom ssh config to define the remote user for a host
+ that does not match the inventory user, you should set this parameter to C(no).
+ type: bool
+ default: yes
+ use_ssh_args:
+ description:
+ - Use the ssh_args specified in ansible.cfg.
+ type: bool
+ default: no
+ version_added: "2.0"
+ rsync_opts:
+ description:
+ - Specify additional rsync options by passing in an array.
+ - Note that an empty string in C(rsync_opts) will end up transferring the current working directory.
+ type: list
+ default:
+ version_added: "1.6"
+ partial:
+ description:
+ - Tells rsync to keep the partial file which should make a subsequent transfer of the rest of the file much faster.
+ type: bool
+ default: no
+ version_added: "2.0"
+ verify_host:
+ description:
+ - Verify destination host key.
+ type: bool
+ default: no
+ version_added: "2.0"
+ private_key:
+ description:
+ - Specify the private key to use for SSH-based rsync connections (e.g. C(~/.ssh/id_rsa)).
+ type: path
+ version_added: "1.6"
+ link_dest:
+ description:
+ - Add a destination to hard link against during the rsync.
+ type: list
+ default:
+ version_added: "2.5"
+notes:
+ - rsync must be installed on both the local and remote host.
+ - For the C(synchronize) module, the "local host" is the host `the synchronize task originates on`, and the "destination host" is the host
+ `synchronize is connecting to`.
+ - The "local host" can be changed to a different host by using `delegate_to`. This enables copying between two remote hosts or entirely on one
+ remote machine.
+ - >
+ The user and permissions for the synchronize `src` are those of the user running the Ansible task on the local host (or the remote_user for a
+ delegate_to host when delegate_to is used).
+ - The user and permissions for the synchronize `dest` are those of the `remote_user` on the destination host or the `become_user` if `become=yes` is active.
+ - In Ansible 2.0 a bug in the synchronize module made become occur on the "local host". This was fixed in Ansible 2.0.1.
+ - Currently, synchronize is limited to elevating permissions via passwordless sudo. This is because rsync itself is connecting to the remote machine
+ and rsync doesn't give us a way to pass sudo credentials in.
+ - Currently there are only a few connection types which support synchronize (ssh, paramiko, local, and docker) because a sync strategy has been
+ determined for those connection types. Note that the connection for these must not need a password as rsync itself is making the connection and
+ rsync does not provide us a way to pass a password to the connection.
+ - Expect that dest=~/x will be ~<remote_user>/x even if using sudo.
+ - Inspect the verbose output to validate the destination user/host/path are what was expected.
+ - To exclude files and directories from being synchronized, you may add C(.rsync-filter) files to the source directory.
+ - The rsync daemon must be up and running with correct permissions when using the rsync protocol in the source or destination path.
+ - The C(synchronize) module forces `--delay-updates` to avoid leaving a destination in a broken in-between state if the underlying rsync process
+ encounters an error. Those synchronizing large numbers of files that are willing to trade safety for performance should call rsync directly.
+ - link_dest is subject to the same limitations as the underlying rsync daemon. Hard links are only preserved if the relative subtrees
+ of the source and destination are the same. Attempts to hardlink into a directory that is a subdirectory of the source will be prevented.
+seealso:
+- module: copy
+- module: win_robocopy
+author:
+- Timothy Appnel (@tima)
+'''
+
+EXAMPLES = '''
+- name: Synchronization of src on the control machine to dest on the remote hosts
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+
+- name: Synchronization using rsync protocol (push)
+ synchronize:
+ src: some/relative/path/
+ dest: rsync://somehost.com/path/
+
+- name: Synchronization using rsync protocol (pull)
+ synchronize:
+ mode: pull
+ src: rsync://somehost.com/path/
+ dest: /some/absolute/path/
+
+- name: Synchronization using rsync protocol on delegate host (push)
+ synchronize:
+ src: /some/absolute/path/
+ dest: rsync://somehost.com/path/
+ delegate_to: delegate.host
+
+- name: Synchronization using rsync protocol on delegate host (pull)
+ synchronize:
+ mode: pull
+ src: rsync://somehost.com/path/
+ dest: /some/absolute/path/
+ delegate_to: delegate.host
+
+- name: Synchronization without any --archive options enabled
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+ archive: no
+
+- name: Synchronization with --archive options enabled except for --recursive
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+ recursive: no
+
+- name: Synchronization with --archive options enabled except for --times, with --checksum option enabled
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+ checksum: yes
+ times: no
+
+- name: Synchronization without --archive options enabled except use --links
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+ archive: no
+ links: yes
+
+- name: Synchronization of two paths both on the control machine
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+ delegate_to: localhost
+
+- name: Synchronization of src on the inventory host to the dest on the localhost in pull mode
+ synchronize:
+ mode: pull
+ src: some/relative/path
+ dest: /some/absolute/path
+
+- name: Synchronization of src on delegate host to dest on the current inventory host.
+ synchronize:
+ src: /first/absolute/path
+ dest: /second/absolute/path
+ delegate_to: delegate.host
+
+- name: Synchronize two directories on one remote host.
+ synchronize:
+ src: /first/absolute/path
+ dest: /second/absolute/path
+ delegate_to: "{{ inventory_hostname }}"
+
+- name: Synchronize and delete files in dest on the remote host that are not found in src of localhost.
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+ delete: yes
+ recursive: yes
+
+# This specific command is granted su privileges on the destination
+- name: Synchronize using an alternate rsync command
+ synchronize:
+ src: some/relative/path
+ dest: /some/absolute/path
+ rsync_path: su -c rsync
+
+# Example .rsync-filter file in the source directory
+# - var # exclude any path whose last part is 'var'
+# - /var # exclude any path starting with 'var' starting at the source directory
+# + /var/conf # include /var/conf even though it was previously excluded
+
+- name: Synchronize passing in extra rsync options
+ synchronize:
+ src: /tmp/helloworld
+ dest: /var/www/helloworld
+ rsync_opts:
+ - "--no-motd"
+ - "--exclude=.git"
+
+# Hardlink files if they didn't change
+- name: Use hardlinks when synchronizing filesystems
+ synchronize:
+ src: /tmp/path_a/foo.txt
+ dest: /tmp/path_b/foo.txt
+ link_dest: /tmp/path_a/
+
+# Specify the rsync binary to use on remote host and on local host
+- hosts: groupofhosts
+ vars:
+ ansible_rsync_path: /usr/gnu/bin/rsync
+
+ tasks:
+ - name: copy /tmp/localpath/ to remote location /tmp/remotepath
+ synchronize:
+ src: /tmp/localpath/
+ dest: /tmp/remotepath
+ rsync_path: /usr/gnu/bin/rsync
+'''
+
+
+import os
+import errno
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils._text import to_bytes, to_native
+from ansible.module_utils.six.moves import shlex_quote
+
+
+client_addr = None
+
+
+def substitute_controller(path):
+ global client_addr
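+ # sshd sets SSH_CLIENT/SSH_CONNECTION to strings that begin with the client
+ # address, e.g. '198.51.100.7 50000 22' (address illustrative); only that
+ # first field is used here to rewrite 'localhost:' paths.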
+ if not client_addr:
+ ssh_env_string = os.environ.get('SSH_CLIENT', None)
+ try:
+ client_addr, _ = ssh_env_string.split(None, 1)
+ except AttributeError:
+ ssh_env_string = os.environ.get('SSH_CONNECTION', None)
+ try:
+ client_addr, _ = ssh_env_string.split(None, 1)
+ except AttributeError:
+ pass
+ if not client_addr:
+ raise ValueError
+
+ if path.startswith('localhost:'):
+ path = path.replace('localhost', client_addr, 1)
+ return path
+
+
+def is_rsh_needed(source, dest):
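+ # e.g. (illustrative) 'user@host:/path' or 'host:/path' need a remote shell,
+ # while 'rsync://host/path' talks to an rsync daemon and '/local/path' does
+ # not need one at all.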
+ if source.startswith('rsync://') or dest.startswith('rsync://'):
+ return False
+ if ':' in source or ':' in dest:
+ return True
+ return False
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ src=dict(type='str', required=True),
+ dest=dict(type='str', required=True),
+ dest_port=dict(type='int'),
+ delete=dict(type='bool', default=False),
+ private_key=dict(type='path'),
+ rsync_path=dict(type='str'),
+ _local_rsync_path=dict(type='path', default='rsync'),
+ _local_rsync_password=dict(type='str', no_log=True),
+ _substitute_controller=dict(type='bool', default=False),
+ archive=dict(type='bool', default=True),
+ checksum=dict(type='bool', default=False),
+ compress=dict(type='bool', default=True),
+ existing_only=dict(type='bool', default=False),
+ dirs=dict(type='bool', default=False),
+ recursive=dict(type='bool'),
+ links=dict(type='bool'),
+ copy_links=dict(type='bool', default=False),
+ perms=dict(type='bool'),
+ times=dict(type='bool'),
+ owner=dict(type='bool'),
+ group=dict(type='bool'),
+ set_remote_user=dict(type='bool', default=True),
+ rsync_timeout=dict(type='int', default=0),
+ rsync_opts=dict(type='list', default=[]),
+ ssh_args=dict(type='str'),
+ partial=dict(type='bool', default=False),
+ verify_host=dict(type='bool', default=False),
+ mode=dict(type='str', default='push', choices=['pull', 'push']),
+ link_dest=dict(type='list')
+ ),
+ supports_check_mode=True,
+ )
+
+ if module.params['_substitute_controller']:
+ try:
+ source = substitute_controller(module.params['src'])
+ dest = substitute_controller(module.params['dest'])
+ except ValueError:
+ module.fail_json(msg='Could not determine controller hostname for rsync to send to')
+ else:
+ source = module.params['src']
+ dest = module.params['dest']
+ dest_port = module.params['dest_port']
+ delete = module.params['delete']
+ private_key = module.params['private_key']
+ rsync_path = module.params['rsync_path']
+ rsync = module.params.get('_local_rsync_path', 'rsync')
+ rsync_password = module.params.get('_local_rsync_password')
+ rsync_timeout = module.params['rsync_timeout']
+ archive = module.params['archive']
+ checksum = module.params['checksum']
+ compress = module.params['compress']
+ existing_only = module.params['existing_only']
+ dirs = module.params['dirs']
+ partial = module.params['partial']
+ # the default of these params depends on the value of archive
+ recursive = module.params['recursive']
+ links = module.params['links']
+ copy_links = module.params['copy_links']
+ perms = module.params['perms']
+ times = module.params['times']
+ owner = module.params['owner']
+ group = module.params['group']
+ rsync_opts = module.params['rsync_opts']
+ ssh_args = module.params['ssh_args']
+ verify_host = module.params['verify_host']
+ link_dest = module.params['link_dest']
+
+ if '/' not in rsync:
+ rsync = module.get_bin_path(rsync, required=True)
+
+ cmd = [rsync, '--delay-updates', '-F']
+ _sshpass_pipe = None
+ if rsync_password:
+ try:
+ module.run_command(["sshpass"])
+ except OSError:
+ module.fail_json(
+ msg="to use rsync connection with passwords, you must install the sshpass program"
+ )
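+ # sshpass's -d option reads the password from the given file descriptor;
+ # the write end of this pipe is filled in just before rsync runs (see
+ # _write_password_to_pipe below).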
+ _sshpass_pipe = os.pipe()
+ cmd = ['sshpass', '-d' + to_native(_sshpass_pipe[0], errors='surrogate_or_strict')] + cmd
+ if compress:
+ cmd.append('--compress')
+ if rsync_timeout:
+ cmd.append('--timeout=%s' % rsync_timeout)
+ if module.check_mode:
+ cmd.append('--dry-run')
+ if delete:
+ cmd.append('--delete-after')
+ if existing_only:
+ cmd.append('--existing')
+ if checksum:
+ cmd.append('--checksum')
+ if copy_links:
+ cmd.append('--copy-links')
+ if archive:
+ cmd.append('--archive')
+ if recursive is False:
+ cmd.append('--no-recursive')
+ if links is False:
+ cmd.append('--no-links')
+ if perms is False:
+ cmd.append('--no-perms')
+ if times is False:
+ cmd.append('--no-times')
+ if owner is False:
+ cmd.append('--no-owner')
+ if group is False:
+ cmd.append('--no-group')
+ else:
+ if recursive is True:
+ cmd.append('--recursive')
+ if links is True:
+ cmd.append('--links')
+ if perms is True:
+ cmd.append('--perms')
+ if times is True:
+ cmd.append('--times')
+ if owner is True:
+ cmd.append('--owner')
+ if group is True:
+ cmd.append('--group')
+ if dirs:
+ cmd.append('--dirs')
+
+ if source.startswith('rsync://') and dest.startswith('rsync://'):
+ module.fail_json(msg='either src or dest must be a localhost', rc=1)
+
+ if is_rsh_needed(source, dest):
+
+ # https://github.com/ansible/ansible/issues/15907
+ has_rsh = False
+ for rsync_opt in rsync_opts:
+ if '--rsh' in rsync_opt:
+ has_rsh = True
+ break
+
+ # if the user has not supplied an --rsh option go ahead and add ours
+ if not has_rsh:
+ ssh_cmd = [module.get_bin_path('ssh', required=True), '-S', 'none']
+ if private_key is not None:
+ ssh_cmd.extend(['-i', private_key])
+ # If the user specified a port value
+ # Note: The action plugin takes care of setting this to a port from
+ # inventory if the user didn't specify an explicit dest_port
+ if dest_port is not None:
+ ssh_cmd.extend(['-o', 'Port=%s' % dest_port])
+ if not verify_host:
+ ssh_cmd.extend(['-o', 'StrictHostKeyChecking=no', '-o', 'UserKnownHostsFile=/dev/null'])
+ ssh_cmd_str = ' '.join(shlex_quote(arg) for arg in ssh_cmd)
+ if ssh_args:
+ ssh_cmd_str += ' %s' % ssh_args
+ cmd.append('--rsh=%s' % ssh_cmd_str)
+
+ if rsync_path:
+ cmd.append('--rsync-path=%s' % rsync_path)
+
+ if rsync_opts:
+ if '' in rsync_opts:
+ module.warn('The empty string is present in rsync_opts which will cause rsync to'
+ ' transfer the current working directory. If this is intended, use "."'
+ ' instead to get rid of this warning. If this is unintended, check for'
+ ' problems in your playbook leading to empty string in rsync_opts.')
+ cmd.extend(rsync_opts)
+
+ if partial:
+ cmd.append('--partial')
+
+ if link_dest:
+ cmd.append('-H')
+ # verbose required because rsync does not believe that adding a
+ # hardlink is actually a change
+ cmd.append('-vv')
+ for x in link_dest:
+ link_path = os.path.abspath(os.path.expanduser(x))
+ destination_path = os.path.abspath(os.path.dirname(dest))
+ if destination_path.find(link_path) == 0:
+ module.fail_json(msg='Hardlinking into a subdirectory of the source would cause recursion. %s and %s' % (destination_path, dest))
+ cmd.append('--link-dest=%s' % link_path)
+
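+ # rsync's %i escape expands to an itemized-changes string whose first
+ # character is '.' when an item is not being transferred; the change
+ # detection after run_command() below keys off that.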
+ changed_marker = '<<CHANGED>>'
+ cmd.append('--out-format=' + changed_marker + '%i %n%L')
+
+ # expand the paths
+ if '@' not in source:
+ source = os.path.expanduser(source)
+ if '@' not in dest:
+ dest = os.path.expanduser(dest)
+
+ cmd.append(source)
+ cmd.append(dest)
+ cmdstr = ' '.join(cmd)
+
+ # If we are using password authentication, write the password into the pipe
+ if rsync_password:
+ def _write_password_to_pipe(proc):
+ os.close(_sshpass_pipe[0])
+ try:
+ os.write(_sshpass_pipe[1], to_bytes(rsync_password) + b'\n')
+ except OSError as exc:
+ # Ignore broken pipe errors if the sshpass process has exited.
+ if exc.errno != errno.EPIPE or proc.poll() is None:
+ raise
+
+ (rc, out, err) = module.run_command(
+ cmd, pass_fds=_sshpass_pipe,
+ before_communicate_callback=_write_password_to_pipe)
+ else:
+ (rc, out, err) = module.run_command(cmd)
+
+ if rc:
+ return module.fail_json(msg=err, rc=rc, cmd=cmdstr)
+
+ if link_dest:
+ # a leading period indicates no change
+ changed = (changed_marker + '.') not in out
+ else:
+ changed = changed_marker in out
+
+ out_clean = out.replace(changed_marker, '')
+ out_lines = out_clean.split('\n')
+ while '' in out_lines:
+ out_lines.remove('')
+ if module._diff:
+ diff = {'prepared': out_clean}
+ return module.exit_json(changed=changed, msg=out_clean,
+ rc=rc, cmd=cmdstr, stdout_lines=out_lines,
+ diff=diff)
+
+ return module.exit_json(changed=changed, msg=out_clean,
+ rc=rc, cmd=cmdstr, stdout_lines=out_lines)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/timezone.py b/test/support/integration/plugins/modules/timezone.py
new file mode 100644
index 00000000..b7439a12
--- /dev/null
+++ b/test/support/integration/plugins/modules/timezone.py
@@ -0,0 +1,909 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2016, Shinichi TAMURA (@tmshn)
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: timezone
+short_description: Configure timezone setting
+description:
+ - This module configures the timezone setting, both of the system clock and of the hardware clock. If you want to set up NTP, use the M(service) module.
+ - It is recommended to restart C(crond) after changing the timezone, otherwise the jobs may run at the wrong time.
+ - Several different tools are used depending on the OS/distribution involved.
+ For Linux, it can use C(timedatectl) or edit C(/etc/sysconfig/clock) or C(/etc/timezone) and C(hwclock).
+ On SmartOS, C(sm-set-timezone) is used; on macOS, C(systemsetup); on BSD, C(/etc/localtime) is modified.
+ On AIX, C(chtz) is used.
+ - As of Ansible 2.3 support was added for SmartOS and BSDs.
+ - As of Ansible 2.4 support was added for macOS.
+ - As of Ansible 2.9 support was added for AIX 6.1+
+ - Windows and HP-UX are not supported; please let us know if you find any other OS/distro in which this fails.
+version_added: "2.2"
+options:
+ name:
+ description:
+ - Name of the timezone for the system clock.
+ - Default is to keep current setting.
+ - B(At least one of name and hwclock are required.)
+ type: str
+ hwclock:
+ description:
+ - Whether the hardware clock is in UTC or in local timezone.
+ - Default is to keep current setting.
+ - Note that changing this option is not recommended, and that it may fail
+ to apply, especially in virtual environments such as AWS.
+ - B(At least one of name and hwclock are required.)
+ - I(Only used on Linux.)
+ type: str
+ aliases: [ rtc ]
+ choices: [ local, UTC ]
+notes:
+ - On SmartOS, the C(sm-set-timezone) utility (part of the smtools package) is required to set the zone's timezone.
+ - On AIX, only Olson/tz database timezones are usable (POSIX is not supported).
+ - An OS reboot is also required on AIX for the new timezone setting to take effect.
+author:
+ - Shinichi TAMURA (@tmshn)
+ - Jasper Lievisse Adriaanse (@jasperla)
+ - Indrajit Raychaudhuri (@indrajitr)
+'''
+
+RETURN = r'''
+diff:
+ description: The differences about the given arguments.
+ returned: success
+ type: complex
+ contains:
+ before:
+ description: The values before change
+ type: dict
+ after:
+ description: The values after change
+ type: dict
+'''
+
+EXAMPLES = r'''
+- name: Set timezone to Asia/Tokyo
+ timezone:
+ name: Asia/Tokyo
+'''
+
+import errno
+import os
+import platform
+import random
+import re
+import string
+import filecmp
+
+from ansible.module_utils.basic import AnsibleModule, get_distribution
+from ansible.module_utils.six import iteritems
+
+
+class Timezone(object):
+ """This is a generic Timezone manipulation class that is subclassed based on platform.
+
+ A subclass may wish to override the following action methods:
+ - get(key, phase) ... get the value from the system at `phase`
+ - set(key, value) ... set the value to the current system
+ """
+
+ def __new__(cls, module):
+ """Return the platform-specific subclass.
+
+ It does not use load_platform_subclass() because it needs to judge based
+ on whether the `timedatectl` command exists and is available.
+
+ Args:
+ module: The AnsibleModule.
+ """
+ if platform.system() == 'Linux':
+ timedatectl = module.get_bin_path('timedatectl')
+ if timedatectl is not None:
+ rc, stdout, stderr = module.run_command(timedatectl)
+ if rc == 0:
+ return super(Timezone, SystemdTimezone).__new__(SystemdTimezone)
+ else:
+ module.warn('timedatectl command was found but not usable: %s. Using another method.' % stderr)
+ return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone)
+ else:
+ return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone)
+ elif re.match('^joyent_.*Z', platform.version()):
+ # platform.system() returns SunOS, which is too broad. So look at the
+ # platform version instead. However we have to ensure that we're not
+ # running in the global zone where changing the timezone has no effect.
+ zonename_cmd = module.get_bin_path('zonename')
+ if zonename_cmd is not None:
+ (rc, stdout, _) = module.run_command(zonename_cmd)
+ if rc == 0 and stdout.strip() == 'global':
+ module.fail_json(msg='Adjusting timezone is not supported in Global Zone')
+
+ return super(Timezone, SmartOSTimezone).__new__(SmartOSTimezone)
+ elif platform.system() == 'Darwin':
+ return super(Timezone, DarwinTimezone).__new__(DarwinTimezone)
+ elif re.match('^(Free|Net|Open)BSD', platform.platform()):
+ return super(Timezone, BSDTimezone).__new__(BSDTimezone)
+ elif platform.system() == 'AIX':
+ AIXoslevel = int(platform.version() + platform.release())
+ if AIXoslevel >= 61:
+ return super(Timezone, AIXTimezone).__new__(AIXTimezone)
+ else:
+ module.fail_json(msg='AIX os level must be >= 61 for timezone module (Target: %s).' % AIXoslevel)
+ else:
+ # Not supported yet
+ return super(Timezone, Timezone).__new__(Timezone)
+
+ def __init__(self, module):
+ """Initialize of the class.
+
+ Args:
+ module: The AnsibleModule.
+ """
+ super(Timezone, self).__init__()
+ self.msg = []
+ # `self.value` holds the values for each params on each phases.
+ # Initially there's only info of "planned" phase, but the
+ # `self.check()` function will fill it out.
+ self.value = dict()
+ for key in module.argument_spec:
+ value = module.params[key]
+ if value is not None:
+ self.value[key] = dict(planned=value)
+ self.module = module
+
+ def abort(self, msg):
+ """Abort the process with error message.
+
+ This is just the wrapper of module.fail_json().
+
+ Args:
+ msg: The error message.
+ """
+ error_msg = ['Error message:', msg]
+ if self.msg:
+ error_msg.append('Other message(s):')
+ error_msg.extend(self.msg)
+ self.module.fail_json(msg='\n'.join(error_msg))
+
+ def execute(self, *commands, **kwargs):
+ """Execute the shell command.
+
+ This is just the wrapper of module.run_command().
+
+ Args:
+ *commands: The command to execute.
+ It will be concatenated with single space.
+ **kwargs: Only 'log' key is checked.
+ If kwargs['log'] is true, record the command to self.msg.
+
+ Returns:
+ stdout: Standard output of the command.
+ """
+ command = ' '.join(commands)
+ (rc, stdout, stderr) = self.module.run_command(command, check_rc=True)
+ if kwargs.get('log', False):
+ self.msg.append('executed `%s`' % command)
+ return stdout
+
+ def diff(self, phase1='before', phase2='after'):
+ """Calculate the difference between given 2 phases.
+
+ Args:
+ phase1, phase2: The names of phase to compare.
+
+ Returns:
+ diff: The difference in values between phase1 and phase2.
+ This is in the format which can be used with the
+ `--diff` option of ansible-playbook.
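+ e.g. (illustrative) {'before': {'name': 'UTC'}, 'after': {'name': 'Asia/Tokyo'}}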
+ """
+ diff = {phase1: {}, phase2: {}}
+ for key, value in iteritems(self.value):
+ diff[phase1][key] = value[phase1]
+ diff[phase2][key] = value[phase2]
+ return diff
+
+ def check(self, phase):
+ """Check the state in given phase and set it to `self.value`.
+
+ Args:
+ phase: The name of the phase to check.
+
+ Returns:
+ NO RETURN VALUE
+ """
+ if phase == 'planned':
+ return
+ for key, value in iteritems(self.value):
+ value[phase] = self.get(key, phase)
+
+ def change(self):
+ """Make the changes effect based on `self.value`."""
+ for key, value in iteritems(self.value):
+ if value['before'] != value['planned']:
+ self.set(key, value['planned'])
+
+ # ===========================================
+ # Platform specific methods (must be replaced by subclass).
+
+ def get(self, key, phase):
+ """Get the value for the key at the given phase.
+
+ Called from self.check().
+
+ Args:
+ key: The key to get the value
+ phase: The phase to get the value
+
+ Return:
+ value: The value for the key at the given phase.
+ """
+ self.abort('get(key, phase) is not implemented on target platform')
+
+ def set(self, key, value):
+ """Set the value for the key (of course, for the phase 'after').
+
+ Called from self.change().
+
+ Args:
+ key: Key to set the value
+ value: Value to set
+ """
+ self.abort('set(key, value) is not implemented on target platform')
+
+ def _verify_timezone(self):
+ tz = self.value['name']['planned']
+ tzfile = '/usr/share/zoneinfo/%s' % tz
+ if not os.path.isfile(tzfile):
+ self.abort('given timezone "%s" is not available' % tz)
+ return tzfile
+
+
+class SystemdTimezone(Timezone):
+ """This is a Timezone manipulation class for systemd-powered Linux.
+
+ It uses the `timedatectl` command to check/set all arguments.
+ """
+
+ regexps = dict(
+ hwclock=re.compile(r'^\s*RTC in local TZ\s*:\s*([^\s]+)', re.MULTILINE),
+ name=re.compile(r'^\s*Time ?zone\s*:\s*([^\s]+)', re.MULTILINE)
+ )
+
+ subcmds = dict(
+ hwclock='set-local-rtc',
+ name='set-timezone'
+ )
+
+ def __init__(self, module):
+ super(SystemdTimezone, self).__init__(module)
+ self.timedatectl = module.get_bin_path('timedatectl', required=True)
+ self.status = dict()
+ # Validate given timezone
+ if 'name' in self.value:
+ self._verify_timezone()
+
+ def _get_status(self, phase):
+ if phase not in self.status:
+ self.status[phase] = self.execute(self.timedatectl, 'status')
+ return self.status[phase]
+
+ def get(self, key, phase):
+ status = self._get_status(phase)
+ value = self.regexps[key].search(status).group(1)
+ if key == 'hwclock':
+ # For key='hwclock'; convert yes/no -> local/UTC
+ if self.module.boolean(value):
+ value = 'local'
+ else:
+ value = 'UTC'
+ return value
+
+ def set(self, key, value):
+ # For key='hwclock'; convert UTC/local -> yes/no
+ if key == 'hwclock':
+ if value == 'local':
+ value = 'yes'
+ else:
+ value = 'no'
+ self.execute(self.timedatectl, self.subcmds[key], value, log=True)
+
+
+class NosystemdTimezone(Timezone):
+ """This is a Timezone manipulation class for non systemd-powered Linux.
+
+ For timezone setting, it edits the following file and reflect changes:
+ - /etc/sysconfig/clock ... RHEL/CentOS
+ - /etc/timezone ... Debian/Ubuntu
+ For hwclock setting, it executes `hwclock --systohc` command with the
+ '--utc' or '--localtime' option.
+ """
+
+ conf_files = dict(
+ name=None, # To be set in __init__
+ hwclock=None, # To be set in __init__
+ adjtime='/etc/adjtime'
+ )
+
+ # It's fine if all three config files don't exist
+ allow_no_file = dict(
+ name=True,
+ hwclock=True,
+ adjtime=True
+ )
+
+ regexps = dict(
+ name=None, # To be set in __init__
+ hwclock=re.compile(r'^UTC\s*=\s*([^\s]+)', re.MULTILINE),
+ adjtime=re.compile(r'^(UTC|LOCAL)$', re.MULTILINE)
+ )
+
+ dist_regexps = dict(
+ SuSE=re.compile(r'^TIMEZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE),
+ redhat=re.compile(r'^ZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE)
+ )
+
+ dist_tzline_format = dict(
+ SuSE='TIMEZONE="%s"\n',
+ redhat='ZONE="%s"\n'
+ )
+
+ def __init__(self, module):
+ super(NosystemdTimezone, self).__init__(module)
+ # Validate given timezone
+ if 'name' in self.value:
+ tzfile = self._verify_timezone()
+ # `--remove-destination` is needed if /etc/localtime is a symlink so
+ # that it overwrites it instead of following it.
+ self.update_timezone = ['%s --remove-destination %s /etc/localtime' % (self.module.get_bin_path('cp', required=True), tzfile)]
+ self.update_hwclock = self.module.get_bin_path('hwclock', required=True)
+ # Distribution-specific configurations
+ if self.module.get_bin_path('dpkg-reconfigure') is not None:
+ # Debian/Ubuntu
+ if 'name' in self.value:
+ self.update_timezone = ['%s -sf %s /etc/localtime' % (self.module.get_bin_path('ln', required=True), tzfile),
+ '%s --frontend noninteractive tzdata' % self.module.get_bin_path('dpkg-reconfigure', required=True)]
+ self.conf_files['name'] = '/etc/timezone'
+ self.conf_files['hwclock'] = '/etc/default/rcS'
+ self.regexps['name'] = re.compile(r'^([^\s]+)', re.MULTILINE)
+ self.tzline_format = '%s\n'
+ else:
+ # RHEL/CentOS/SUSE
+ if self.module.get_bin_path('tzdata-update') is not None:
+ # tzdata-update cannot update the timezone if /etc/localtime is
+ # a symlink so we have to use cp to update the time zone which
+ # was set above.
+ if not os.path.islink('/etc/localtime'):
+ self.update_timezone = [self.module.get_bin_path('tzdata-update', required=True)]
+ # else:
+ # self.update_timezone = 'cp --remove-destination ...' <- configured above
+ self.conf_files['name'] = '/etc/sysconfig/clock'
+ self.conf_files['hwclock'] = '/etc/sysconfig/clock'
+ try:
+ f = open(self.conf_files['name'], 'r')
+ except IOError as err:
+ if self._allow_ioerror(err, 'name'):
+ # If the config file doesn't exist detect the distribution and set regexps.
+ distribution = get_distribution()
+ if distribution == 'SuSE':
+ # For SUSE
+ self.regexps['name'] = self.dist_regexps['SuSE']
+ self.tzline_format = self.dist_tzline_format['SuSE']
+ else:
+ # For RHEL/CentOS
+ self.regexps['name'] = self.dist_regexps['redhat']
+ self.tzline_format = self.dist_tzline_format['redhat']
+ else:
+ self.abort('could not read configuration file "%s"' % self.conf_files['name'])
+ else:
+ # The key for timezone might be `ZONE` or `TIMEZONE`
+ # (the former is used in RHEL/CentOS and the latter is used in SUSE linux).
+ # So check the content of /etc/sysconfig/clock and decide which key to use.
+ sysconfig_clock = f.read()
+ f.close()
+ if re.search(r'^TIMEZONE\s*=', sysconfig_clock, re.MULTILINE):
+ # For SUSE
+ self.regexps['name'] = self.dist_regexps['SuSE']
+ self.tzline_format = self.dist_tzline_format['SuSE']
+ else:
+ # For RHEL/CentOS
+ self.regexps['name'] = self.dist_regexps['redhat']
+ self.tzline_format = self.dist_tzline_format['redhat']
+
+ def _allow_ioerror(self, err, key):
+ # In some cases, even if the target file does not exist,
+ # simply creating it may solve the problem.
+ # In such cases, we should continue the configuration rather than aborting.
+ if err.errno != errno.ENOENT:
+ # If the error is not ENOENT ("No such file or directory"),
+ # (e.g., permission error, etc), we should abort.
+ return False
+ return self.allow_no_file.get(key, False)
+
+ def _edit_file(self, filename, regexp, value, key):
+ """Replace the first matched line with given `value`.
+
+        If `regexp` matches more than one line, all matched lines other than
+        the first are deleted.
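+        (For example, if the file contains two `UTC=` lines, the first is
+        replaced with `value` and the second is removed.)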
+
+ Args:
+ filename: The name of the file to edit.
+ regexp: The regular expression to search with.
+ value: The line which will be inserted.
+            key: The key for which the file is being edited.
+ """
+ # Read the file
+ try:
+ file = open(filename, 'r')
+ except IOError as err:
+ if self._allow_ioerror(err, key):
+ lines = []
+ else:
+ self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename))
+ else:
+ lines = file.readlines()
+ file.close()
+        # Find all matched lines
+ matched_indices = []
+ for i, line in enumerate(lines):
+ if regexp.search(line):
+ matched_indices.append(i)
+ if len(matched_indices) > 0:
+ insert_line = matched_indices[0]
+ else:
+ insert_line = 0
+ # Remove all matched lines
+ for i in matched_indices[::-1]:
+ del lines[i]
+ # ...and insert the value
+ lines.insert(insert_line, value)
+ # Write the changes
+ try:
+ file = open(filename, 'w')
+ except IOError:
+ self.abort('tried to configure %s using a file "%s", but could not write to it' % (key, filename))
+ else:
+ file.writelines(lines)
+ file.close()
+ self.msg.append('Added 1 line and deleted %s line(s) on %s' % (len(matched_indices), filename))
+
+ def _get_value_from_config(self, key, phase):
+ filename = self.conf_files[key]
+ try:
+ file = open(filename, mode='r')
+ except IOError as err:
+ if self._allow_ioerror(err, key):
+ if key == 'hwclock':
+ return 'n/a'
+ elif key == 'adjtime':
+ return 'UTC'
+ elif key == 'name':
+ return 'n/a'
+ else:
+ self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename))
+ else:
+ status = file.read()
+ file.close()
+ try:
+ value = self.regexps[key].search(status).group(1)
+ except AttributeError:
+ if key == 'hwclock':
+ # If we cannot find UTC in the config that's fine.
+ return 'n/a'
+                elif key == 'adjtime':
+                    # If we cannot find UTC/LOCAL in /etc/adjtime, that means
+                    # UTC will be used by default.
+ return 'UTC'
+ elif key == 'name':
+                    if phase == 'before':
+                        # In the 'before' phase the timezone name doesn't need
+                        # to be set in the config file yet, so we ignore this error.
+ return 'n/a'
+ else:
+ self.abort('tried to configure %s using a file "%s", but could not find a valid value in it' % (key, filename))
+ else:
+ if key == 'hwclock':
+ # convert yes/no -> UTC/local
+ if self.module.boolean(value):
+ value = 'UTC'
+ else:
+ value = 'local'
+ elif key == 'adjtime':
+ # convert LOCAL -> local
+ if value != 'UTC':
+ value = value.lower()
+ return value
+
+ def get(self, key, phase):
+ planned = self.value[key]['planned']
+ if key == 'hwclock':
+ value = self._get_value_from_config(key, phase)
+ if value == planned:
+ # If the value in the config file is the same as the 'planned'
+ # value, we need to check /etc/adjtime.
+ value = self._get_value_from_config('adjtime', phase)
+ elif key == 'name':
+ value = self._get_value_from_config(key, phase)
+            if value == planned:
+                # If the planned value is the same as the one in the config file,
+                # we need to check whether /etc/localtime is also set to the 'planned' zone.
+                if os.path.islink('/etc/localtime'):
+                    # If /etc/localtime is a symlink and does not point to the TZ
+                    # we 'planned' to set, return the TZ the symlink actually points to.
+ if os.path.exists('/etc/localtime'):
+ # We use readlink() because on some distros zone files are symlinks
+ # to other zone files, so it's hard to get which TZ is actually set
+ # if we follow the symlink.
+ path = os.readlink('/etc/localtime')
+ linktz = re.search(r'/usr/share/zoneinfo/(.*)', path, re.MULTILINE)
+ if linktz:
+ valuelink = linktz.group(1)
+ if valuelink != planned:
+ value = valuelink
+ else:
+ # Set current TZ to 'n/a' if the symlink points to a path
+ # which isn't a zone file.
+ value = 'n/a'
+ else:
+ # Set current TZ to 'n/a' if the symlink to the zone file is broken.
+ value = 'n/a'
+                else:
+                    # If /etc/localtime is not a symlink, the best we can do is
+                    # compare it with the 'planned' zone info file and return
+                    # 'n/a' if they differ.
+ try:
+ if not filecmp.cmp('/etc/localtime', '/usr/share/zoneinfo/' + planned):
+ return 'n/a'
+ except Exception:
+ return 'n/a'
+ else:
+ self.abort('unknown parameter "%s"' % key)
+ return value
+
+ def set_timezone(self, value):
+ self._edit_file(filename=self.conf_files['name'],
+ regexp=self.regexps['name'],
+ value=self.tzline_format % value,
+ key='name')
+ for cmd in self.update_timezone:
+ self.execute(cmd)
+
+ def set_hwclock(self, value):
+ if value == 'local':
+ option = '--localtime'
+ utc = 'no'
+ else:
+ option = '--utc'
+ utc = 'yes'
+ if self.conf_files['hwclock'] is not None:
+ self._edit_file(filename=self.conf_files['hwclock'],
+ regexp=self.regexps['hwclock'],
+ value='UTC=%s\n' % utc,
+ key='hwclock')
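+        # hwclock --systohc writes the current system time to the hardware
+        # clock; --utc/--localtime record which timescale the clock keeps.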
+ self.execute(self.update_hwclock, '--systohc', option, log=True)
+
+ def set(self, key, value):
+ if key == 'name':
+ self.set_timezone(value)
+ elif key == 'hwclock':
+ self.set_hwclock(value)
+ else:
+ self.abort('unknown parameter "%s"' % key)
+
+
+class SmartOSTimezone(Timezone):
+ """This is a Timezone manipulation class for SmartOS instances.
+
+ It uses the C(sm-set-timezone) utility to set the timezone, and
+ inspects C(/etc/default/init) to determine the current timezone.
+
+ NB: A zone needs to be rebooted in order for the change to be
+ activated.
+ """
+
+ def __init__(self, module):
+ super(SmartOSTimezone, self).__init__(module)
+ self.settimezone = self.module.get_bin_path('sm-set-timezone', required=False)
+ if not self.settimezone:
+ module.fail_json(msg='sm-set-timezone not found. Make sure the smtools package is installed.')
+
+ def get(self, key, phase):
+ """Lookup the current timezone name in `/etc/default/init`. If anything else
+ is requested, or if the TZ field is not set we fail.
+ """
+ if key == 'name':
+ try:
+ f = open('/etc/default/init', 'r')
+ for line in f:
+ m = re.match('^TZ=(.*)$', line.strip())
+ if m:
+ return m.groups()[0]
+ except Exception:
+ self.module.fail_json(msg='Failed to read /etc/default/init')
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+ def set(self, key, value):
+ """Set the requested timezone through sm-set-timezone, an invalid timezone name
+ will be rejected and we have no further input validation to perform.
+ """
+ if key == 'name':
+ cmd = 'sm-set-timezone %s' % value
+
+ (rc, stdout, stderr) = self.module.run_command(cmd)
+
+ if rc != 0:
+ self.module.fail_json(msg=stderr)
+
+ # sm-set-timezone knows no state and will always set the timezone.
+ # XXX: https://github.com/joyent/smtools/pull/2
+            m = re.match(r'^\* Changed (to)? timezone (to)? (%s).*' % re.escape(value), stdout.splitlines()[1])
+ if not (m and m.groups()[-1] == value):
+ self.module.fail_json(msg='Failed to set timezone')
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+
+class DarwinTimezone(Timezone):
+ """This is the timezone implementation for Darwin which, unlike other *BSD
+    implementations, uses the `systemsetup` command to check/set
+ the timezone.
+ """
+
+ regexps = dict(
+ name=re.compile(r'^\s*Time ?Zone\s*:\s*([^\s]+)', re.MULTILINE)
+ )
+
+ def __init__(self, module):
+ super(DarwinTimezone, self).__init__(module)
+ self.systemsetup = module.get_bin_path('systemsetup', required=True)
+ self.status = dict()
+ # Validate given timezone
+ if 'name' in self.value:
+ self._verify_timezone()
+
+ def _get_current_timezone(self, phase):
+ """Lookup the current timezone via `systemsetup -gettimezone`."""
+ if phase not in self.status:
+ self.status[phase] = self.execute(self.systemsetup, '-gettimezone')
+ return self.status[phase]
+
+ def _verify_timezone(self):
+ tz = self.value['name']['planned']
+        # Look up the list of supported timezones via `systemsetup -listtimezones`.
+        # Note: Skip the first line, which contains the label 'Time Zones:'
+        out = self.execute(self.systemsetup, '-listtimezones').splitlines()[1:]
+        tz_list = [line.strip() for line in out]
+ if tz not in tz_list:
+ self.abort('given timezone "%s" is not available' % tz)
+ return tz
+
+ def get(self, key, phase):
+ if key == 'name':
+ status = self._get_current_timezone(phase)
+ value = self.regexps[key].search(status).group(1)
+ return value
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+ def set(self, key, value):
+ if key == 'name':
+ self.execute(self.systemsetup, '-settimezone', value, log=True)
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+
+class BSDTimezone(Timezone):
+ """This is the timezone implementation for *BSD which works simply through
+ updating the `/etc/localtime` symlink to point to a valid timezone name under
+ `/usr/share/zoneinfo`.
+ """
+
+ def __init__(self, module):
+ super(BSDTimezone, self).__init__(module)
+
+ def __get_timezone(self):
+ zoneinfo_dir = '/usr/share/zoneinfo/'
+ localtime_file = '/etc/localtime'
+
+ # Strategy 1:
+        # If /etc/localtime does not exist, assume the timezone is UTC.
+ if not os.path.exists(localtime_file):
+ self.module.warn('Could not read /etc/localtime. Assuming UTC.')
+ return 'UTC'
+
+ # Strategy 2:
+ # Follow symlink of /etc/localtime
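+        # (Each link in the chain is followed in turn; the while/else construct
+        # means the else branch runs only when the loop condition becomes false,
+        # i.e. the chain ended inside zoneinfo_dir, not when it exits via break.)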
+ zoneinfo_file = localtime_file
+ while not zoneinfo_file.startswith(zoneinfo_dir):
+ try:
+                zoneinfo_file = os.readlink(zoneinfo_file)
+ except OSError:
+ # OSError means "end of symlink chain" or broken link.
+ break
+ else:
+ return zoneinfo_file.replace(zoneinfo_dir, '')
+
+ # Strategy 3:
+ # (If /etc/localtime is not symlinked)
+ # Check all files in /usr/share/zoneinfo and return first non-link match.
+ for dname, _, fnames in sorted(os.walk(zoneinfo_dir)):
+ for fname in sorted(fnames):
+ zoneinfo_file = os.path.join(dname, fname)
+ if not os.path.islink(zoneinfo_file) and filecmp.cmp(zoneinfo_file, localtime_file):
+ return zoneinfo_file.replace(zoneinfo_dir, '')
+
+ # Strategy 4:
+ # As a fall-back, return 'UTC' as default assumption.
+ self.module.warn('Could not identify timezone name from /etc/localtime. Assuming UTC.')
+ return 'UTC'
+
+ def get(self, key, phase):
+ """Lookup the current timezone by resolving `/etc/localtime`."""
+ if key == 'name':
+ return self.__get_timezone()
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+ def set(self, key, value):
+ if key == 'name':
+ # First determine if the requested timezone is valid by looking in
+ # the zoneinfo directory.
+ zonefile = '/usr/share/zoneinfo/' + value
+ try:
+ if not os.path.isfile(zonefile):
+ self.module.fail_json(msg='%s is not a recognized timezone' % value)
+ except Exception:
+ self.module.fail_json(msg='Failed to stat %s' % zonefile)
+
+ # Now (somewhat) atomically update the symlink by creating a new
+ # symlink and move it into place. Otherwise we have to remove the
+ # original symlink and create the new symlink, however that would
+ # create a race condition in case another process tries to read
+ # /etc/localtime between removal and creation.
+ suffix = "".join([random.choice(string.ascii_letters + string.digits) for x in range(0, 10)])
+ new_localtime = '/etc/localtime.' + suffix
+
+ try:
+ os.symlink(zonefile, new_localtime)
+ os.rename(new_localtime, '/etc/localtime')
+ except Exception:
+ os.remove(new_localtime)
+ self.module.fail_json(msg='Could not update /etc/localtime')
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+
+class AIXTimezone(Timezone):
+ """This is a Timezone manipulation class for AIX instances.
+
+ It uses the C(chtz) utility to set the timezone, and
+ inspects C(/etc/environment) to determine the current timezone.
+
+    While AIX time zones can be set using two formats (POSIX and
+    Olson), the preferred method is Olson.
+ See the following article for more information:
+ https://developer.ibm.com/articles/au-aix-posix/
+
+ NB: AIX needs to be rebooted in order for the change to be
+ activated.
+ """
+
+ def __init__(self, module):
+ super(AIXTimezone, self).__init__(module)
+ self.settimezone = self.module.get_bin_path('chtz', required=True)
+
+ def __get_timezone(self):
+ """ Return the current value of TZ= in /etc/environment """
+ try:
+ f = open('/etc/environment', 'r')
+ etcenvironment = f.read()
+ f.close()
+ except Exception:
+ self.module.fail_json(msg='Issue reading contents of /etc/environment')
+
+ match = re.search(r'^TZ=(.*)$', etcenvironment, re.MULTILINE)
+ if match:
+ return match.group(1)
+ else:
+ return None
+
+ def get(self, key, phase):
+ """Lookup the current timezone name in `/etc/environment`. If anything else
+ is requested, or if the TZ field is not set we fail.
+ """
+ if key == 'name':
+ return self.__get_timezone()
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+ def set(self, key, value):
+ """Set the requested timezone through chtz, an invalid timezone name
+ will be rejected and we have no further input validation to perform.
+ """
+ if key == 'name':
+            # chtz seems to always return 0 on AIX 7.2, even for invalid timezone values.
+            # It only returns non-zero if the chtz command itself fails; it does not check
+            # for valid timezones. We need to perform a basic check to confirm that the
+            # timezone definition exists in /usr/share/lib/zoneinfo.
+            # This does mean that we can only support Olson names for now. The commented-out
+            # regex below detects Olson-style names, so in the future we could detect POSIX
+            # or Olson and act accordingly.
+
+ # regex_olson = re.compile('^([a-z0-9_\-\+]+\/?)+$', re.IGNORECASE)
+ # if not regex_olson.match(value):
+            #     msg = 'Supplied timezone (%s) does not appear to be a valid Olson string' % value
+ # self.module.fail_json(msg=msg)
+
+ # First determine if the requested timezone is valid by looking in the zoneinfo
+ # directory.
+ zonefile = '/usr/share/lib/zoneinfo/' + value
+ try:
+ if not os.path.isfile(zonefile):
+ self.module.fail_json(msg='%s is not a recognized timezone.' % value)
+ except Exception:
+ self.module.fail_json(msg='Failed to check %s.' % zonefile)
+
+ # Now set the TZ using chtz
+ cmd = 'chtz %s' % value
+ (rc, stdout, stderr) = self.module.run_command(cmd)
+
+ if rc != 0:
+ self.module.fail_json(msg=stderr)
+
+            # The best check we can do is to compare the value of TZ after making
+            # the change.
+ TZ = self.__get_timezone()
+ if TZ != value:
+ msg = 'TZ value does not match post-change (Actual: %s, Expected: %s).' % (TZ, value)
+ self.module.fail_json(msg=msg)
+
+ else:
+ self.module.fail_json(msg='%s is not a supported option on target platform' % key)
+
+
+def main():
+ # Construct 'module' and 'tz'
+ module = AnsibleModule(
+ argument_spec=dict(
+ hwclock=dict(type='str', choices=['local', 'UTC'], aliases=['rtc']),
+ name=dict(type='str'),
+ ),
+ required_one_of=[
+ ['hwclock', 'name']
+ ],
+ supports_check_mode=True,
+ )
+ tz = Timezone(module)
+
+ # Check the current state
+ tz.check(phase='before')
+ if module.check_mode:
+ diff = tz.diff('before', 'planned')
+ # In check mode, 'planned' state is treated as 'after' state
+ diff['after'] = diff.pop('planned')
+ else:
+ # Make change
+ tz.change()
+ # Check the current state
+ tz.check(phase='after')
+ # Examine if the current state matches planned state
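+        # (tz.diff('after', 'planned') returns {'after': ..., 'planned': ...},
+        # so unpacking .values() relies on that insertion order.)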
+ (after, planned) = tz.diff('after', 'planned').values()
+ if after != planned:
+            tz.abort('still not in the desired state, though changes have been made - '
+                     'planned: %s, after: %s' % (str(planned), str(after)))
+ diff = tz.diff('before', 'after')
+
+ changed = (diff['before'] != diff['after'])
+ if len(tz.msg) > 0:
+ module.exit_json(changed=changed, diff=diff, msg='\n'.join(tz.msg))
+ else:
+ module.exit_json(changed=changed, diff=diff)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/x509_crl.py b/test/support/integration/plugins/modules/x509_crl.py
new file mode 100644
index 00000000..ef601eda
--- /dev/null
+++ b/test/support/integration/plugins/modules/x509_crl.py
@@ -0,0 +1,783 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2019, Felix Fontein <felix@fontein.de>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: x509_crl
+version_added: "2.10"
+short_description: Generate Certificate Revocation Lists (CRLs)
+description:
+ - This module allows one to (re)generate or update Certificate Revocation Lists (CRLs).
+ - Certificates on the revocation list can be either specified via serial number and (optionally) their issuer,
+ or as a path to a certificate file in PEM format.
+requirements:
+ - cryptography >= 1.2
+author:
+ - Felix Fontein (@felixfontein)
+options:
+ state:
+ description:
+ - Whether the CRL file should exist or not, taking action if the state is different from what is stated.
+ type: str
+ default: present
+ choices: [ absent, present ]
+
+ mode:
+ description:
+ - Defines how to process entries of existing CRLs.
+ - If set to C(generate), makes sure that the CRL has the exact set of revoked certificates
+ as specified in I(revoked_certificates).
+ - If set to C(update), makes sure that the CRL contains the revoked certificates from
+ I(revoked_certificates), but can also contain other revoked certificates. If the CRL file
+ already exists, all entries from the existing CRL will also be included in the new CRL.
+ When using C(update), you might be interested in setting I(ignore_timestamps) to C(yes).
+ type: str
+ default: generate
+ choices: [ generate, update ]
+
+ force:
+ description:
+ - Should the CRL be forced to be regenerated.
+ type: bool
+ default: no
+
+ backup:
+ description:
+ - Create a backup file including a timestamp so you can get the original
+ CRL back if you overwrote it with a new one by accident.
+ type: bool
+ default: no
+
+ path:
+ description:
+ - Remote absolute path where the generated CRL file should be created or is already located.
+ type: path
+ required: yes
+
+ privatekey_path:
+ description:
+ - Path to the CA's private key to use when signing the CRL.
+ - Either I(privatekey_path) or I(privatekey_content) must be specified if I(state) is C(present), but not both.
+ type: path
+
+ privatekey_content:
+ description:
+ - The content of the CA's private key to use when signing the CRL.
+ - Either I(privatekey_path) or I(privatekey_content) must be specified if I(state) is C(present), but not both.
+ type: str
+
+ privatekey_passphrase:
+ description:
+ - The passphrase for the I(privatekey_path).
+ - This is required if the private key is password protected.
+ type: str
+
+ issuer:
+ description:
+ - Key/value pairs that will be present in the issuer name field of the CRL.
+ - If you need to specify more than one value with the same key, use a list as value.
+ - Required if I(state) is C(present).
+ type: dict
+
+ last_update:
+ description:
+ - The point in time from which this CRL can be trusted.
+ - Time can be specified either as relative time or as absolute timestamp.
+ - Time will always be interpreted as UTC.
+ - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer
+      + C([w | d | h | m | s]) (e.g. C(+32w1d2h)).
+ - Note that if using relative time this module is NOT idempotent, except when
+ I(ignore_timestamps) is set to C(yes).
+ type: str
+ default: "+0s"
+
+ next_update:
+ description:
+ - "The absolute latest point in time by which this I(issuer) is expected to have issued
+ another CRL. Many clients will treat a CRL as expired once I(next_update) occurs."
+ - Time can be specified either as relative time or as absolute timestamp.
+ - Time will always be interpreted as UTC.
+ - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer
+      + C([w | d | h | m | s]) (e.g. C(+32w1d2h)).
+ - Note that if using relative time this module is NOT idempotent, except when
+ I(ignore_timestamps) is set to C(yes).
+ - Required if I(state) is C(present).
+ type: str
+
+ digest:
+ description:
+ - Digest algorithm to be used when signing the CRL.
+ type: str
+ default: sha256
+
+ revoked_certificates:
+ description:
+ - List of certificates to be revoked.
+ - Required if I(state) is C(present).
+ type: list
+ elements: dict
+ suboptions:
+ path:
+ description:
+ - Path to a certificate in PEM format.
+ - The serial number and issuer will be extracted from the certificate.
+ - Mutually exclusive with I(content) and I(serial_number). One of these three options
+ must be specified.
+ type: path
+ content:
+ description:
+ - Content of a certificate in PEM format.
+ - The serial number and issuer will be extracted from the certificate.
+ - Mutually exclusive with I(path) and I(serial_number). One of these three options
+ must be specified.
+ type: str
+ serial_number:
+ description:
+ - Serial number of the certificate.
+ - Mutually exclusive with I(path) and I(content). One of these three options must
+ be specified.
+ type: int
+ revocation_date:
+ description:
+ - The point in time the certificate was revoked.
+ - Time can be specified either as relative time or as absolute timestamp.
+ - Time will always be interpreted as UTC.
+ - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer
+          + C([w | d | h | m | s]) (e.g. C(+32w1d2h)).
+ - Note that if using relative time this module is NOT idempotent, except when
+ I(ignore_timestamps) is set to C(yes).
+ type: str
+ default: "+0s"
+ issuer:
+ description:
+ - The certificate's issuer.
+ - "Example: C(DNS:ca.example.org)"
+ type: list
+ elements: str
+ issuer_critical:
+ description:
+ - Whether the certificate issuer extension should be critical.
+ type: bool
+ default: no
+ reason:
+ description:
+ - The value for the revocation reason extension.
+ type: str
+ choices:
+ - unspecified
+ - key_compromise
+ - ca_compromise
+ - affiliation_changed
+ - superseded
+ - cessation_of_operation
+ - certificate_hold
+ - privilege_withdrawn
+ - aa_compromise
+ - remove_from_crl
+ reason_critical:
+ description:
+ - Whether the revocation reason extension should be critical.
+ type: bool
+ default: no
+ invalidity_date:
+ description:
+ - The point in time it was known/suspected that the private key was compromised
+ or that the certificate otherwise became invalid.
+ - Time can be specified either as relative time or as absolute timestamp.
+ - Time will always be interpreted as UTC.
+ - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer
+          + C([w | d | h | m | s]) (e.g. C(+32w1d2h)).
+ - Note that if using relative time this module is NOT idempotent. This will NOT
+ change when I(ignore_timestamps) is set to C(yes).
+ type: str
+ invalidity_date_critical:
+ description:
+ - Whether the invalidity date extension should be critical.
+ type: bool
+ default: no
+
+ ignore_timestamps:
+ description:
+ - Whether the timestamps I(last_update), I(next_update) and I(revocation_date) (in
+ I(revoked_certificates)) should be ignored for idempotency checks. The timestamp
+ I(invalidity_date) in I(revoked_certificates) will never be ignored.
+ - Use this in combination with relative timestamps for these values to get idempotency.
+ type: bool
+ default: no
+
+ return_content:
+ description:
+ - If set to C(yes), will return the (current or generated) CRL's content as I(crl).
+ type: bool
+ default: no
+
+extends_documentation_fragment:
+ - files
+
+notes:
+ - All ASN.1 TIME values should be specified following the YYYYMMDDHHMMSSZ pattern.
+  - Dates specified should be UTC. Minutes and seconds are mandatory.
+'''
+
+EXAMPLES = r'''
+- name: Generate a CRL
+ x509_crl:
+ path: /etc/ssl/my-ca.crl
+ privatekey_path: /etc/ssl/private/my-ca.pem
+ issuer:
+ CN: My CA
+ last_update: "+0s"
+ next_update: "+7d"
+ revoked_certificates:
+ - serial_number: 1234
+ revocation_date: 20190331202428Z
+ issuer:
+ CN: My CA
+ - serial_number: 2345
+ revocation_date: 20191013152910Z
+ reason: affiliation_changed
+ invalidity_date: 20191001000000Z
+ - path: /etc/ssl/crt/revoked-cert.pem
+ revocation_date: 20191010010203Z
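+
+# A minimal sketch of update mode: keep every entry already in the CRL and add
+# one more. With relative timestamps, ignore_timestamps is needed for idempotency.
+- name: Add a revoked certificate to an existing CRL
+  x509_crl:
+    path: /etc/ssl/my-ca.crl
+    privatekey_path: /etc/ssl/private/my-ca.pem
+    issuer:
+      CN: My CA
+    next_update: "+7d"
+    mode: update
+    ignore_timestamps: yes
+    revoked_certificates:
+      - serial_number: 3456
+        reason: key_compromise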
+'''
+
+RETURN = r'''
+filename:
+ description: Path to the generated CRL
+ returned: changed or success
+ type: str
+ sample: /path/to/my-ca.crl
+backup_file:
+ description: Name of backup file created.
+ returned: changed and if I(backup) is C(yes)
+ type: str
+ sample: /path/to/my-ca.crl.2019-03-09@11:22~
+privatekey:
+ description: Path to the private CA key
+ returned: changed or success
+ type: str
+ sample: /path/to/my-ca.pem
+issuer:
+ description:
+ - The CRL's issuer.
+ - Note that for repeated values, only the last one will be returned.
+ returned: success
+ type: dict
+ sample: '{"organizationName": "Ansible", "commonName": "ca.example.com"}'
+issuer_ordered:
+ description: The CRL's issuer as an ordered list of tuples.
+ returned: success
+ type: list
+ elements: list
+ sample: '[["organizationName", "Ansible"], ["commonName": "ca.example.com"]]'
+last_update:
+  description: The point in time from which this CRL can be trusted, as ASN.1 TIME.
+ returned: success
+ type: str
+ sample: 20190413202428Z
+next_update:
+  description: The point in time by which a new CRL will be issued and the client has to check for it, as ASN.1 TIME.
+ returned: success
+ type: str
+ sample: 20190413202428Z
+digest:
+ description: The signature algorithm used to sign the CRL.
+ returned: success
+ type: str
+ sample: sha256WithRSAEncryption
+revoked_certificates:
+ description: List of certificates to be revoked.
+ returned: success
+ type: list
+ elements: dict
+ contains:
+ serial_number:
+ description: Serial number of the certificate.
+ type: int
+ sample: 1234
+ revocation_date:
+ description: The point in time the certificate was revoked as ASN.1 TIME.
+ type: str
+ sample: 20190413202428Z
+ issuer:
+ description: The certificate's issuer.
+ type: list
+ elements: str
+ sample: '["DNS:ca.example.org"]'
+ issuer_critical:
+ description: Whether the certificate issuer extension is critical.
+ type: bool
+ sample: no
+ reason:
+ description:
+ - The value for the revocation reason extension.
+ - One of C(unspecified), C(key_compromise), C(ca_compromise), C(affiliation_changed), C(superseded),
+ C(cessation_of_operation), C(certificate_hold), C(privilege_withdrawn), C(aa_compromise), and
+ C(remove_from_crl).
+ type: str
+ sample: key_compromise
+ reason_critical:
+ description: Whether the revocation reason extension is critical.
+ type: bool
+ sample: no
+ invalidity_date:
+ description: |
+ The point in time it was known/suspected that the private key was compromised
+ or that the certificate otherwise became invalid as ASN.1 TIME.
+ type: str
+ sample: 20190413202428Z
+ invalidity_date_critical:
+ description: Whether the invalidity date extension is critical.
+ type: bool
+ sample: no
+crl:
+ description: The (current or generated) CRL's content.
+ returned: if I(state) is C(present) and I(return_content) is C(yes)
+ type: str
+'''
+
+
+import os
+import traceback
+from distutils.version import LooseVersion
+
+from ansible.module_utils import crypto as crypto_utils
+from ansible.module_utils._text import to_native, to_text
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+
+MINIMAL_CRYPTOGRAPHY_VERSION = '1.2'
+
+CRYPTOGRAPHY_IMP_ERR = None
+try:
+ import cryptography
+ from cryptography import x509
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.hazmat.primitives.serialization import Encoding
+ from cryptography.x509 import (
+ CertificateRevocationListBuilder,
+ RevokedCertificateBuilder,
+ NameAttribute,
+ Name,
+ )
+ CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
+except ImportError:
+ CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
+ CRYPTOGRAPHY_FOUND = False
+else:
+ CRYPTOGRAPHY_FOUND = True
+
+
+TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
+
+
+class CRLError(crypto_utils.OpenSSLObjectError):
+ pass
+
+
+class CRL(crypto_utils.OpenSSLObject):
+
+ def __init__(self, module):
+ super(CRL, self).__init__(
+ module.params['path'],
+ module.params['state'],
+ module.params['force'],
+ module.check_mode
+ )
+
+ self.update = module.params['mode'] == 'update'
+ self.ignore_timestamps = module.params['ignore_timestamps']
+ self.return_content = module.params['return_content']
+ self.crl_content = None
+
+ self.privatekey_path = module.params['privatekey_path']
+ self.privatekey_content = module.params['privatekey_content']
+ if self.privatekey_content is not None:
+ self.privatekey_content = self.privatekey_content.encode('utf-8')
+ self.privatekey_passphrase = module.params['privatekey_passphrase']
+
+ self.issuer = crypto_utils.parse_name_field(module.params['issuer'])
+ self.issuer = [(entry[0], entry[1]) for entry in self.issuer if entry[1]]
+
+ self.last_update = crypto_utils.get_relative_time_option(module.params['last_update'], 'last_update')
+ self.next_update = crypto_utils.get_relative_time_option(module.params['next_update'], 'next_update')
+
+ self.digest = crypto_utils.select_message_digest(module.params['digest'])
+ if self.digest is None:
+ raise CRLError('The digest "{0}" is not supported'.format(module.params['digest']))
+
+ self.revoked_certificates = []
+ for i, rc in enumerate(module.params['revoked_certificates']):
+ result = {
+ 'serial_number': None,
+ 'revocation_date': None,
+ 'issuer': None,
+ 'issuer_critical': False,
+ 'reason': None,
+ 'reason_critical': False,
+ 'invalidity_date': None,
+ 'invalidity_date_critical': False,
+ }
+ path_prefix = 'revoked_certificates[{0}].'.format(i)
+ if rc['path'] is not None or rc['content'] is not None:
+ # Load certificate from file or content
+ try:
+ if rc['content'] is not None:
+ rc['content'] = rc['content'].encode('utf-8')
+ cert = crypto_utils.load_certificate(rc['path'], content=rc['content'], backend='cryptography')
+ try:
+ result['serial_number'] = cert.serial_number
+ except AttributeError:
+ # The property was called "serial" before cryptography 1.4
+ result['serial_number'] = cert.serial
+ except crypto_utils.OpenSSLObjectError as e:
+ if rc['content'] is not None:
+ module.fail_json(
+ msg='Cannot parse certificate from {0}content: {1}'.format(path_prefix, to_native(e))
+ )
+ else:
+ module.fail_json(
+ msg='Cannot read certificate "{1}" from {0}path: {2}'.format(path_prefix, rc['path'], to_native(e))
+ )
+ else:
+ # Specify serial_number (and potentially issuer) directly
+ result['serial_number'] = rc['serial_number']
+ # All other options
+ if rc['issuer']:
+ result['issuer'] = [crypto_utils.cryptography_get_name(issuer) for issuer in rc['issuer']]
+ result['issuer_critical'] = rc['issuer_critical']
+ result['revocation_date'] = crypto_utils.get_relative_time_option(
+ rc['revocation_date'],
+ path_prefix + 'revocation_date'
+ )
+ if rc['reason']:
+ result['reason'] = crypto_utils.REVOCATION_REASON_MAP[rc['reason']]
+ result['reason_critical'] = rc['reason_critical']
+ if rc['invalidity_date']:
+ result['invalidity_date'] = crypto_utils.get_relative_time_option(
+ rc['invalidity_date'],
+ path_prefix + 'invalidity_date'
+ )
+ result['invalidity_date_critical'] = rc['invalidity_date_critical']
+ self.revoked_certificates.append(result)
+
+ self.module = module
+
+ self.backup = module.params['backup']
+ self.backup_file = None
+
+ try:
+ self.privatekey = crypto_utils.load_privatekey(
+ path=self.privatekey_path,
+ content=self.privatekey_content,
+ passphrase=self.privatekey_passphrase,
+ backend='cryptography'
+ )
+ except crypto_utils.OpenSSLBadPassphraseError as exc:
+ raise CRLError(exc)
+
+ self.crl = None
+ try:
+ with open(self.path, 'rb') as f:
+ data = f.read()
+ self.crl = x509.load_pem_x509_crl(data, default_backend())
+ if self.return_content:
+ self.crl_content = data
+ except Exception as dummy:
+ self.crl_content = None
+
+ def remove(self):
+ if self.backup:
+ self.backup_file = self.module.backup_local(self.path)
+ super(CRL, self).remove(self.module)
+
+ def _compress_entry(self, entry):
+ if self.ignore_timestamps:
+ # Throw out revocation_date
+ return (
+ entry['serial_number'],
+ tuple(entry['issuer']) if entry['issuer'] is not None else None,
+ entry['issuer_critical'],
+ entry['reason'],
+ entry['reason_critical'],
+ entry['invalidity_date'],
+ entry['invalidity_date_critical'],
+ )
+ else:
+ return (
+ entry['serial_number'],
+ entry['revocation_date'],
+ tuple(entry['issuer']) if entry['issuer'] is not None else None,
+ entry['issuer_critical'],
+ entry['reason'],
+ entry['reason_critical'],
+ entry['invalidity_date'],
+ entry['invalidity_date_critical'],
+ )
+
+ def check(self, perms_required=True):
+ """Ensure the resource is in its desired state."""
+
+ state_and_perms = super(CRL, self).check(self.module, perms_required)
+
+ if not state_and_perms:
+ return False
+
+ if self.crl is None:
+ return False
+
+ if self.last_update != self.crl.last_update and not self.ignore_timestamps:
+ return False
+ if self.next_update != self.crl.next_update and not self.ignore_timestamps:
+ return False
+ if self.digest.name != self.crl.signature_hash_algorithm.name:
+ return False
+
+ want_issuer = [(crypto_utils.cryptography_name_to_oid(entry[0]), entry[1]) for entry in self.issuer]
+ if want_issuer != [(sub.oid, sub.value) for sub in self.crl.issuer]:
+ return False
+
+ old_entries = [self._compress_entry(crypto_utils.cryptography_decode_revoked_certificate(cert)) for cert in self.crl]
+ new_entries = [self._compress_entry(cert) for cert in self.revoked_certificates]
+ if self.update:
+ # We don't simply use a set so that duplicate entries are treated correctly
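+            # Removing each new entry from old_entries implements a multiset
+            # containment check: every entry must occur in the existing CRL at
+            # least as often as in revoked_certificates.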
+ for entry in new_entries:
+ try:
+ old_entries.remove(entry)
+ except ValueError:
+ return False
+ else:
+ if old_entries != new_entries:
+ return False
+
+ return True
+
+ def _generate_crl(self):
+ backend = default_backend()
+ crl = CertificateRevocationListBuilder()
+
+ try:
+ crl = crl.issuer_name(Name([
+ NameAttribute(crypto_utils.cryptography_name_to_oid(entry[0]), to_text(entry[1]))
+ for entry in self.issuer
+ ]))
+ except ValueError as e:
+ raise CRLError(e)
+
+ crl = crl.last_update(self.last_update)
+ crl = crl.next_update(self.next_update)
+
+ if self.update and self.crl:
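+            # Carry over entries from the existing CRL unless an entry in
+            # revoked_certificates supersedes them.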
+ new_entries = set([self._compress_entry(entry) for entry in self.revoked_certificates])
+ for entry in self.crl:
+ decoded_entry = self._compress_entry(crypto_utils.cryptography_decode_revoked_certificate(entry))
+ if decoded_entry not in new_entries:
+ crl = crl.add_revoked_certificate(entry)
+ for entry in self.revoked_certificates:
+ revoked_cert = RevokedCertificateBuilder()
+ revoked_cert = revoked_cert.serial_number(entry['serial_number'])
+ revoked_cert = revoked_cert.revocation_date(entry['revocation_date'])
+ if entry['issuer'] is not None:
+ revoked_cert = revoked_cert.add_extension(
+                    # entry['issuer'] already holds cryptography name objects,
+                    # converted with cryptography_get_name() in __init__.
+                    x509.CertificateIssuer(entry['issuer']),
+ entry['issuer_critical']
+ )
+ if entry['reason'] is not None:
+ revoked_cert = revoked_cert.add_extension(
+ x509.CRLReason(entry['reason']),
+ entry['reason_critical']
+ )
+ if entry['invalidity_date'] is not None:
+ revoked_cert = revoked_cert.add_extension(
+ x509.InvalidityDate(entry['invalidity_date']),
+ entry['invalidity_date_critical']
+ )
+ crl = crl.add_revoked_certificate(revoked_cert.build(backend))
+
+ self.crl = crl.sign(self.privatekey, self.digest, backend=backend)
+ return self.crl.public_bytes(Encoding.PEM)
+
+ def generate(self):
+ if not self.check(perms_required=False) or self.force:
+ result = self._generate_crl()
+ if self.return_content:
+ self.crl_content = result
+ if self.backup:
+ self.backup_file = self.module.backup_local(self.path)
+ crypto_utils.write_file(self.module, result)
+ self.changed = True
+
+ file_args = self.module.load_file_common_arguments(self.module.params)
+ if self.module.set_fs_attributes_if_different(file_args, False):
+ self.changed = True
+
+ def _dump_revoked(self, entry):
+ return {
+ 'serial_number': entry['serial_number'],
+ 'revocation_date': entry['revocation_date'].strftime(TIMESTAMP_FORMAT),
+ 'issuer':
+ [crypto_utils.cryptography_decode_name(issuer) for issuer in entry['issuer']]
+ if entry['issuer'] is not None else None,
+ 'issuer_critical': entry['issuer_critical'],
+ 'reason': crypto_utils.REVOCATION_REASON_MAP_INVERSE.get(entry['reason']) if entry['reason'] is not None else None,
+ 'reason_critical': entry['reason_critical'],
+ 'invalidity_date':
+ entry['invalidity_date'].strftime(TIMESTAMP_FORMAT)
+ if entry['invalidity_date'] is not None else None,
+ 'invalidity_date_critical': entry['invalidity_date_critical'],
+ }
+
+ def dump(self, check_mode=False):
+ result = {
+ 'changed': self.changed,
+ 'filename': self.path,
+ 'privatekey': self.privatekey_path,
+ 'last_update': None,
+ 'next_update': None,
+ 'digest': None,
+ 'issuer_ordered': None,
+ 'issuer': None,
+ 'revoked_certificates': [],
+ }
+ if self.backup_file:
+ result['backup_file'] = self.backup_file
+
+ if check_mode:
+ result['last_update'] = self.last_update.strftime(TIMESTAMP_FORMAT)
+ result['next_update'] = self.next_update.strftime(TIMESTAMP_FORMAT)
+ # result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid)
+ result['digest'] = self.module.params['digest']
+ result['issuer_ordered'] = self.issuer
+ result['issuer'] = {}
+ for k, v in self.issuer:
+ result['issuer'][k] = v
+ result['revoked_certificates'] = []
+ for entry in self.revoked_certificates:
+ result['revoked_certificates'].append(self._dump_revoked(entry))
+ elif self.crl:
+ result['last_update'] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
+ result['next_update'] = self.crl.next_update.strftime(TIMESTAMP_FORMAT)
+ try:
+ result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid)
+ except AttributeError:
+ # Older cryptography versions don't have signature_algorithm_oid yet
+ dotted = crypto_utils._obj2txt(
+ self.crl._backend._lib,
+ self.crl._backend._ffi,
+ self.crl._x509_crl.sig_alg.algorithm
+ )
+ oid = x509.oid.ObjectIdentifier(dotted)
+ result['digest'] = crypto_utils.cryptography_oid_to_name(oid)
+ issuer = []
+ for attribute in self.crl.issuer:
+ issuer.append([crypto_utils.cryptography_oid_to_name(attribute.oid), attribute.value])
+ result['issuer_ordered'] = issuer
+ result['issuer'] = {}
+ for k, v in issuer:
+ result['issuer'][k] = v
+ result['revoked_certificates'] = []
+ for cert in self.crl:
+ entry = crypto_utils.cryptography_decode_revoked_certificate(cert)
+ result['revoked_certificates'].append(self._dump_revoked(entry))
+
+ if self.return_content:
+ result['crl'] = self.crl_content
+
+ return result
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ state=dict(type='str', default='present', choices=['present', 'absent']),
+ mode=dict(type='str', default='generate', choices=['generate', 'update']),
+ force=dict(type='bool', default=False),
+ backup=dict(type='bool', default=False),
+ path=dict(type='path', required=True),
+ privatekey_path=dict(type='path'),
+ privatekey_content=dict(type='str'),
+ privatekey_passphrase=dict(type='str', no_log=True),
+ issuer=dict(type='dict'),
+ last_update=dict(type='str', default='+0s'),
+ next_update=dict(type='str'),
+ digest=dict(type='str', default='sha256'),
+ ignore_timestamps=dict(type='bool', default=False),
+ return_content=dict(type='bool', default=False),
+ revoked_certificates=dict(
+ type='list',
+ elements='dict',
+ options=dict(
+ path=dict(type='path'),
+ content=dict(type='str'),
+ serial_number=dict(type='int'),
+ revocation_date=dict(type='str', default='+0s'),
+ issuer=dict(type='list', elements='str'),
+ issuer_critical=dict(type='bool', default=False),
+ reason=dict(
+ type='str',
+ choices=[
+ 'unspecified', 'key_compromise', 'ca_compromise', 'affiliation_changed',
+ 'superseded', 'cessation_of_operation', 'certificate_hold',
+ 'privilege_withdrawn', 'aa_compromise', 'remove_from_crl'
+ ]
+ ),
+ reason_critical=dict(type='bool', default=False),
+ invalidity_date=dict(type='str'),
+ invalidity_date_critical=dict(type='bool', default=False),
+ ),
+ required_one_of=[['path', 'content', 'serial_number']],
+ mutually_exclusive=[['path', 'content', 'serial_number']],
+ ),
+ ),
+ required_if=[
+ ('state', 'present', ['privatekey_path', 'privatekey_content'], True),
+ ('state', 'present', ['issuer', 'next_update', 'revoked_certificates'], False),
+ ],
+ mutually_exclusive=(
+ ['privatekey_path', 'privatekey_content'],
+ ),
+ supports_check_mode=True,
+ add_file_common_args=True,
+ )
+
+ if not CRYPTOGRAPHY_FOUND:
+ module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
+ exception=CRYPTOGRAPHY_IMP_ERR)
+
+ try:
+ crl = CRL(module)
+
+ if module.params['state'] == 'present':
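+            # In check mode, only report whether the CRL would change.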
+ if module.check_mode:
+ result = crl.dump(check_mode=True)
+ result['changed'] = module.params['force'] or not crl.check()
+ module.exit_json(**result)
+
+ crl.generate()
+ else:
+ if module.check_mode:
+ result = crl.dump(check_mode=True)
+ result['changed'] = os.path.exists(module.params['path'])
+ module.exit_json(**result)
+
+ crl.remove()
+
+ result = crl.dump()
+ module.exit_json(**result)
+ except crypto_utils.OpenSSLObjectError as exc:
+ module.fail_json(msg=to_native(exc))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/integration/plugins/modules/x509_crl_info.py b/test/support/integration/plugins/modules/x509_crl_info.py
new file mode 100644
index 00000000..b61db26f
--- /dev/null
+++ b/test/support/integration/plugins/modules/x509_crl_info.py
@@ -0,0 +1,281 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2020, Felix Fontein <felix@fontein.de>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: x509_crl_info
+version_added: "2.10"
+short_description: Retrieve information on Certificate Revocation Lists (CRLs)
+description:
+ - This module allows one to retrieve information on Certificate Revocation Lists (CRLs).
+requirements:
+ - cryptography >= 1.2
+author:
+ - Felix Fontein (@felixfontein)
+options:
+ path:
+ description:
+      - Remote absolute path where the CRL file is located.
+ - Either I(path) or I(content) must be specified, but not both.
+ type: path
+ content:
+ description:
+      - Content of the X.509 CRL in PEM format.
+ - Either I(path) or I(content) must be specified, but not both.
+ type: str
+
+notes:
+ - All timestamp values are provided in ASN.1 TIME format, i.e. following the C(YYYYMMDDHHMMSSZ) pattern.
+ They are all in UTC.
+seealso:
+ - module: x509_crl
+'''
+
+EXAMPLES = r'''
+- name: Get information on CRL
+ x509_crl_info:
+ path: /etc/ssl/my-ca.crl
+ register: result
+
+- debug:
+ msg: "{{ result }}"
+'''
+
+RETURN = r'''
+issuer:
+ description:
+ - The CRL's issuer.
+ - Note that for repeated values, only the last one will be returned.
+ returned: success
+ type: dict
+ sample: '{"organizationName": "Ansible", "commonName": "ca.example.com"}'
+issuer_ordered:
+ description: The CRL's issuer as an ordered list of tuples.
+ returned: success
+ type: list
+ elements: list
+ sample: '[["organizationName", "Ansible"], ["commonName": "ca.example.com"]]'
+last_update:
+  description: The point in time from which this CRL can be trusted, as ASN.1 TIME.
+ returned: success
+ type: str
+ sample: 20190413202428Z
+next_update:
+  description: The point in time by which a new CRL will be issued and the client has to check for it, as ASN.1 TIME.
+ returned: success
+ type: str
+ sample: 20190413202428Z
+digest:
+ description: The signature algorithm used to sign the CRL.
+ returned: success
+ type: str
+ sample: sha256WithRSAEncryption
+revoked_certificates:
+ description: List of certificates to be revoked.
+ returned: success
+ type: list
+ elements: dict
+ contains:
+ serial_number:
+ description: Serial number of the certificate.
+ type: int
+ sample: 1234
+ revocation_date:
+ description: The point in time the certificate was revoked as ASN.1 TIME.
+ type: str
+ sample: 20190413202428Z
+ issuer:
+ description: The certificate's issuer.
+ type: list
+ elements: str
+ sample: '["DNS:ca.example.org"]'
+ issuer_critical:
+ description: Whether the certificate issuer extension is critical.
+ type: bool
+ sample: no
+ reason:
+ description:
+ - The value for the revocation reason extension.
+ - One of C(unspecified), C(key_compromise), C(ca_compromise), C(affiliation_changed), C(superseded),
+ C(cessation_of_operation), C(certificate_hold), C(privilege_withdrawn), C(aa_compromise), and
+ C(remove_from_crl).
+ type: str
+ sample: key_compromise
+ reason_critical:
+ description: Whether the revocation reason extension is critical.
+ type: bool
+ sample: no
+ invalidity_date:
+ description: |
+ The point in time it was known/suspected that the private key was compromised
+ or that the certificate otherwise became invalid as ASN.1 TIME.
+ type: str
+ sample: 20190413202428Z
+ invalidity_date_critical:
+ description: Whether the invalidity date extension is critical.
+ type: bool
+ sample: no
+'''
+
+
+import traceback
+from distutils.version import LooseVersion
+
+from ansible.module_utils import crypto as crypto_utils
+from ansible.module_utils._text import to_native
+from ansible.module_utils.basic import AnsibleModule, missing_required_lib
+
+MINIMAL_CRYPTOGRAPHY_VERSION = '1.2'
+
+CRYPTOGRAPHY_IMP_ERR = None
+try:
+ import cryptography
+ from cryptography import x509
+ from cryptography.hazmat.backends import default_backend
+ CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
+except ImportError:
+ CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
+ CRYPTOGRAPHY_FOUND = False
+else:
+ CRYPTOGRAPHY_FOUND = True
+
+
+TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
+
+
+class CRLError(crypto_utils.OpenSSLObjectError):
+ pass
+
+
+class CRLInfo(crypto_utils.OpenSSLObject):
+ """The main module implementation."""
+
+ def __init__(self, module):
+ super(CRLInfo, self).__init__(
+ module.params['path'] or '',
+ 'present',
+ False,
+ module.check_mode
+ )
+
+ self.content = module.params['content']
+
+ self.module = module
+
+ self.crl = None
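+        # Load the CRL either from disk (path) or from the given PEM content.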
+ if self.content is None:
+ try:
+ with open(self.path, 'rb') as f:
+ data = f.read()
+ except Exception as e:
+ self.module.fail_json(msg='Error while reading CRL file from disk: {0}'.format(e))
+ else:
+ data = self.content.encode('utf-8')
+
+ try:
+ self.crl = x509.load_pem_x509_crl(data, default_backend())
+ except Exception as e:
+ self.module.fail_json(msg='Error while decoding CRL: {0}'.format(e))
+
+ def _dump_revoked(self, entry):
+ return {
+ 'serial_number': entry['serial_number'],
+ 'revocation_date': entry['revocation_date'].strftime(TIMESTAMP_FORMAT),
+ 'issuer':
+ [crypto_utils.cryptography_decode_name(issuer) for issuer in entry['issuer']]
+ if entry['issuer'] is not None else None,
+ 'issuer_critical': entry['issuer_critical'],
+ 'reason': crypto_utils.REVOCATION_REASON_MAP_INVERSE.get(entry['reason']) if entry['reason'] is not None else None,
+ 'reason_critical': entry['reason_critical'],
+ 'invalidity_date':
+ entry['invalidity_date'].strftime(TIMESTAMP_FORMAT)
+ if entry['invalidity_date'] is not None else None,
+ 'invalidity_date_critical': entry['invalidity_date_critical'],
+ }
+
+ def get_info(self):
+ result = {
+ 'changed': False,
+ 'last_update': None,
+ 'next_update': None,
+ 'digest': None,
+ 'issuer_ordered': None,
+ 'issuer': None,
+ 'revoked_certificates': [],
+ }
+
+ result['last_update'] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
+ result['next_update'] = self.crl.next_update.strftime(TIMESTAMP_FORMAT)
+ try:
+ result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid)
+ except AttributeError:
+ # Older cryptography versions don't have signature_algorithm_oid yet
+ dotted = crypto_utils._obj2txt(
+ self.crl._backend._lib,
+ self.crl._backend._ffi,
+ self.crl._x509_crl.sig_alg.algorithm
+ )
+ oid = x509.oid.ObjectIdentifier(dotted)
+ result['digest'] = crypto_utils.cryptography_oid_to_name(oid)
+ issuer = []
+ for attribute in self.crl.issuer:
+ issuer.append([crypto_utils.cryptography_oid_to_name(attribute.oid), attribute.value])
+ result['issuer_ordered'] = issuer
+ result['issuer'] = {}
+ for k, v in issuer:
+ result['issuer'][k] = v
+ result['revoked_certificates'] = []
+ for cert in self.crl:
+ entry = crypto_utils.cryptography_decode_revoked_certificate(cert)
+ result['revoked_certificates'].append(self._dump_revoked(entry))
+
+ return result
+
+ def generate(self):
+ # Empty method because crypto_utils.OpenSSLObject wants this
+ pass
+
+ def dump(self):
+ # Empty method because crypto_utils.OpenSSLObject wants this
+ pass
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ path=dict(type='path'),
+ content=dict(type='str'),
+ ),
+ required_one_of=(
+ ['path', 'content'],
+ ),
+ mutually_exclusive=(
+ ['path', 'content'],
+ ),
+ supports_check_mode=True,
+ )
+
+ if not CRYPTOGRAPHY_FOUND:
+ module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
+ exception=CRYPTOGRAPHY_IMP_ERR)
+
+ try:
+ crl = CRLInfo(module)
+ result = crl.get_info()
+ module.exit_json(**result)
+ except crypto_utils.OpenSSLObjectError as e:
+ module.fail_json(msg=to_native(e))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/integration/plugins/modules/xml.py b/test/support/integration/plugins/modules/xml.py
new file mode 100644
index 00000000..b5b35a38
--- /dev/null
+++ b/test/support/integration/plugins/modules/xml.py
@@ -0,0 +1,966 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2014, Red Hat, Inc.
+# Copyright: (c) 2014, Tim Bielawa <tbielawa@redhat.com>
+# Copyright: (c) 2014, Magnus Hedemark <mhedemar@redhat.com>
+# Copyright: (c) 2017, Dag Wieers <dag@wieers.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: xml
+short_description: Manage bits and pieces of XML files or strings
+description:
+- A CRUD-like interface to managing bits of XML files.
+version_added: '2.4'
+options:
+ path:
+ description:
+ - Path to the file to operate on.
+ - This file must exist ahead of time.
+ - This parameter is required, unless C(xmlstring) is given.
+ type: path
+ required: yes
+ aliases: [ dest, file ]
+ xmlstring:
+ description:
+ - A string containing XML on which to operate.
+ - This parameter is required, unless C(path) is given.
+ type: str
+ required: yes
+ xpath:
+ description:
+ - A valid XPath expression describing the item(s) you want to manipulate.
+ - Operates on the document root, C(/), by default.
+ type: str
+ namespaces:
+ description:
+ - The namespace C(prefix:uri) mapping for the XPath expression.
+ - Needs to be a C(dict), not a C(list) of items.
+ type: dict
+ state:
+ description:
+ - Set or remove an xpath selection (node(s), attribute(s)).
+ type: str
+ choices: [ absent, present ]
+ default: present
+ aliases: [ ensure ]
+ attribute:
+ description:
+ - The attribute to select when using parameter C(value).
+ - This is a string, not prepended with C(@).
+ type: raw
+ value:
+ description:
+ - Desired state of the selected attribute.
+    - Either a string, or to unset a value, the Python C(None) keyword (YAML equivalent: C(null)).
+ - Elements default to no value (but present).
+ - Attributes default to an empty string.
+ type: raw
+ add_children:
+ description:
+ - Add additional child-element(s) to a selected element for a given C(xpath).
+ - Child elements must be given in a list and each item may be either a string
+      (e.g. C(children=ansible) to add an empty C(<ansible/>) child element),
+ or a hash where the key is an element name and the value is the element value.
+ - This parameter requires C(xpath) to be set.
+ type: list
+ set_children:
+ description:
+ - Set the child-element(s) of a selected element for a given C(xpath).
+ - Removes any existing children.
+ - Child elements must be specified as in C(add_children).
+ - This parameter requires C(xpath) to be set.
+ type: list
+ count:
+ description:
+ - Search for a given C(xpath) and provide the count of any matches.
+ - This parameter requires C(xpath) to be set.
+ type: bool
+ default: no
+ print_match:
+ description:
+ - Search for a given C(xpath) and print out any matches.
+ - This parameter requires C(xpath) to be set.
+ type: bool
+ default: no
+ pretty_print:
+ description:
+ - Pretty print XML output.
+ type: bool
+ default: no
+ content:
+ description:
+ - Search for a given C(xpath) and get content.
+ - This parameter requires C(xpath) to be set.
+ type: str
+ choices: [ attribute, text ]
+ input_type:
+ description:
+ - Type of input for C(add_children) and C(set_children).
+ type: str
+ choices: [ xml, yaml ]
+ default: yaml
+ backup:
+ description:
+ - Create a backup file including the timestamp information so you can get
+ the original file back if you somehow clobbered it incorrectly.
+ type: bool
+ default: no
+ strip_cdata_tags:
+ description:
+ - Remove CDATA tags surrounding text values.
+ - Note that this might break your XML file if text values contain characters that could be interpreted as XML.
+ type: bool
+ default: no
+ version_added: '2.7'
+ insertbefore:
+ description:
+ - Add additional child-element(s) before the first selected element for a given C(xpath).
+ - Child elements must be given in a list and each item may be either a string
+      (e.g. C(children=ansible) to add an empty C(<ansible/>) child element),
+ or a hash where the key is an element name and the value is the element value.
+ - This parameter requires C(xpath) to be set.
+ type: bool
+ default: no
+ version_added: '2.8'
+ insertafter:
+ description:
+ - Add additional child-element(s) after the last selected element for a given C(xpath).
+ - Child elements must be given in a list and each item may be either a string
+      (e.g. C(children=ansible) to add an empty C(<ansible/>) child element),
+ or a hash where the key is an element name and the value is the element value.
+ - This parameter requires C(xpath) to be set.
+ type: bool
+ default: no
+ version_added: '2.8'
+requirements:
+- lxml >= 2.3.0
+notes:
+- Use the C(--check) and C(--diff) options when testing your expressions.
+- The diff output is automatically pretty-printed, so may not reflect the actual file content, only the file structure.
+- This module does not handle complicated xpath expressions, so limit xpath selectors to simple expressions.
+- Beware that in case your XML elements are namespaced, you need to use the C(namespaces) parameter, see the examples.
+- A namespace prefix should be used for all children of an element where a namespace is defined, unless another namespace is defined for them.
+seealso:
+- name: Xml module development community wiki
+ description: More information related to the development of this xml module.
+ link: https://github.com/ansible/community/wiki/Module:-xml
+- name: Introduction to XPath
+ description: A brief tutorial on XPath (w3schools.com).
+ link: https://www.w3schools.com/xml/xpath_intro.asp
+- name: XPath Reference document
+ description: The reference documentation on XSLT/XPath (developer.mozilla.org).
+ link: https://developer.mozilla.org/en-US/docs/Web/XPath
+author:
+- Tim Bielawa (@tbielawa)
+- Magnus Hedemark (@magnus919)
+- Dag Wieers (@dagwieers)
+'''
+
+EXAMPLES = r'''
+# Consider the following XML file:
+#
+# <business type="bar">
+# <name>Tasty Beverage Co.</name>
+# <beers>
+# <beer>Rochefort 10</beer>
+# <beer>St. Bernardus Abbot 12</beer>
+# <beer>Schlitz</beer>
+# </beers>
+# <rating subjective="true">10</rating>
+# <website>
+# <mobilefriendly/>
+# <address>http://tastybeverageco.com</address>
+# </website>
+# </business>
+
+- name: Remove the 'subjective' attribute of the 'rating' element
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/rating/@subjective
+ state: absent
+
+- name: Set the rating to '11'
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/rating
+ value: 11
+
+# Retrieve and display the number of nodes
+- name: Get count of 'beer' nodes
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/beers/beer
+ count: yes
+ register: hits
+
+- debug:
+ var: hits.count
+
+# Example where parent XML nodes are created automatically
+- name: Add a 'phonenumber' element to the 'business' element
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/phonenumber
+ value: 555-555-1234
+
+- name: Add several more beers to the 'beers' element
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/beers
+ add_children:
+ - beer: Old Rasputin
+ - beer: Old Motor Oil
+ - beer: Old Curmudgeon
+
+- name: Add several more beers to the 'beers' element and add them before the 'Rochefort 10' element
+ xml:
+ path: /foo/bar.xml
+ xpath: '/business/beers/beer[text()="Rochefort 10"]'
+ insertbefore: yes
+ add_children:
+ - beer: Old Rasputin
+ - beer: Old Motor Oil
+ - beer: Old Curmudgeon
+
+# NOTE: The 'state' defaults to 'present' and 'value' defaults to 'null' for elements
+- name: Add a 'validxhtml' element to the 'website' element
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/website/validxhtml
+
+- name: Add an empty 'validatedon' attribute to the 'validxhtml' element
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/website/validxhtml/@validatedon
+
+- name: Add or modify an attribute, add element if needed
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/website/validxhtml
+ attribute: validatedon
+ value: 1976-08-05
+
+# How to read an attribute value and access it in Ansible
+- name: Read an element's attribute values
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/website/validxhtml
+ content: attribute
+ register: xmlresp
+
+- name: Show an attribute value
+ debug:
+ var: xmlresp.matches[0].validxhtml.validatedon
+
+- name: Remove all children from the 'website' element (option 1)
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/website/*
+ state: absent
+
+- name: Remove all children from the 'website' element (option 2)
+ xml:
+ path: /foo/bar.xml
+ xpath: /business/website
+ children: []
+
+# In case of namespaces, like in the XML below, they have to be stated explicitly.
+#
+# <foo xmlns="http://x.test" xmlns:attr="http://z.test">
+# <bar>
+# <baz xmlns="http://y.test" attr:my_namespaced_attribute="true" />
+# </bar>
+# </foo>
+
+# NOTE: The prefix 'x' is needed in front of the 'bar' element, too.
+- name: Set namespaced '/x:foo/x:bar/y:baz/@z:my_namespaced_attribute' to 'false'
+ xml:
+ path: foo.xml
+ xpath: /x:foo/x:bar/y:baz
+ namespaces:
+ x: http://x.test
+ y: http://y.test
+ z: http://z.test
+ attribute: z:my_namespaced_attribute
+ value: 'false'
+'''
+
+RETURN = r'''
+actions:
+ description: A dictionary with the original xpath, namespaces and state.
+ type: dict
+ returned: success
+    sample: {xpath: xpath, namespaces: [namespace1, namespace2], state: present}
+backup_file:
+    description: The name of the backup file that was created.
+ type: str
+ returned: when backup=yes
+ sample: /path/to/file.xml.1942.2017-08-24@14:16:01~
+count:
+ description: The count of xpath matches.
+ type: int
+ returned: when parameter 'count' is set
+ sample: 2
+matches:
+ description: The xpath matches found.
+ type: list
+ returned: when parameter 'print_match' is set
+msg:
+ description: A message related to the performed action(s).
+ type: str
+ returned: always
+xmlstring:
+ description: An XML string of the resulting output.
+ type: str
+ returned: when parameter 'xmlstring' is set
+'''
+
+import copy
+import json
+import os
+import re
+import traceback
+
+from distutils.version import LooseVersion
+from io import BytesIO
+
+LXML_IMP_ERR = None
+try:
+ from lxml import etree, objectify
+ HAS_LXML = True
+except ImportError:
+ LXML_IMP_ERR = traceback.format_exc()
+ HAS_LXML = False
+
+from ansible.module_utils.basic import AnsibleModule, json_dict_bytes_to_unicode, missing_required_lib
+from ansible.module_utils.six import iteritems, string_types
+from ansible.module_utils._text import to_bytes, to_native
+from ansible.module_utils.common._collections_compat import MutableMapping
+
+_IDENT = r"[a-zA-Z-][a-zA-Z0-9_\-\.]*"
+_NSIDENT = _IDENT + "|" + _IDENT + ":" + _IDENT
+# Note: we can't reasonably support the XPath trick of concatenating strings wrapped in the
+# other quote delimiter (used when a literal needs to contain both ' and "), especially not
+# as simple XPath.
+
+_RE_SPLITSIMPLELAST = re.compile("^(.*)/(" + _NSIDENT + ")$")
+_RE_SPLITSIMPLELASTEQVALUE = re.compile("^(.*)/(" + _NSIDENT + ")/text\\(\\)=" + _XPSTR + "$")
+_RE_SPLITSIMPLEATTRLAST = re.compile("^(.*)/(@(?:" + _NSIDENT + "))$")
+_RE_SPLITSIMPLEATTRLASTEQVALUE = re.compile("^(.*)/(@(?:" + _NSIDENT + "))=" + _XPSTR + "$")
+_RE_SPLITSUBLAST = re.compile("^(.*)/(" + _NSIDENT + ")\\[(.*)\\]$")
+_RE_SPLITONLYEQVALUE = re.compile("^(.*)/text\\(\\)=" + _XPSTR + "$")
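+# Illustrative "simple XPath" forms the patterns above are meant to split (a
+# sketch, using the sample document from EXAMPLES):
+#   /business/name                          -> _RE_SPLITSIMPLELAST
+#   /business/name/text()='Tasty'           -> _RE_SPLITSIMPLELASTEQVALUE
+#   /business/rating/@subjective            -> _RE_SPLITSIMPLEATTRLAST
+#   /business/rating/@subjective='true'     -> _RE_SPLITSIMPLEATTRLASTEQVALUE
+#   /business/beers/beer[text()="Schlitz"]  -> _RE_SPLITSUBLAST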
+
+
+def has_changed(doc):
+    # Compare a normalized serialization of the current tree against the global
+    # 'orig_doc' snapshot taken in main() before any modifications were applied.
+    orig_obj = etree.tostring(objectify.fromstring(etree.tostring(orig_doc)))
+    obj = etree.tostring(objectify.fromstring(etree.tostring(doc)))
+    return (orig_obj != obj)
+
+
+def do_print_match(module, tree, xpath, namespaces):
+ match = tree.xpath(xpath, namespaces=namespaces)
+ match_xpaths = []
+ for m in match:
+ match_xpaths.append(tree.getpath(m))
+ match_str = json.dumps(match_xpaths)
+ msg = "selector '%s' match: %s" % (xpath, match_str)
+ finish(module, tree, xpath, namespaces, changed=False, msg=msg)
+
+
+def count_nodes(module, tree, xpath, namespaces):
+ """ Return the count of nodes matching the xpath """
+ hits = tree.xpath("count(/%s)" % xpath, namespaces=namespaces)
+ msg = "found %d nodes" % hits
+ finish(module, tree, xpath, namespaces, changed=False, msg=msg, hitcount=int(hits))
+
+
+def is_node(tree, xpath, namespaces):
+ """ Test if a given xpath matches anything and if that match is a node.
+
+ For now we just assume you're only searching for one specific thing."""
+ if xpath_matches(tree, xpath, namespaces):
+ # OK, it found something
+ match = tree.xpath(xpath, namespaces=namespaces)
+ if isinstance(match[0], etree._Element):
+ return True
+
+ return False
+
+
+def is_attribute(tree, xpath, namespaces):
+ """ Test if a given xpath matches and that match is an attribute
+
+ An xpath attribute search will only match one item"""
+ if xpath_matches(tree, xpath, namespaces):
+ match = tree.xpath(xpath, namespaces=namespaces)
+ if isinstance(match[0], etree._ElementStringResult):
+ return True
+ elif isinstance(match[0], etree._ElementUnicodeResult):
+ return True
+ return False
+
+
+def xpath_matches(tree, xpath, namespaces):
+ """ Test if a node exists """
+ if tree.xpath(xpath, namespaces=namespaces):
+ return True
+ return False
+
+
+def delete_xpath_target(module, tree, xpath, namespaces):
+ """ Delete an attribute or element from a tree """
+ try:
+ for result in tree.xpath(xpath, namespaces=namespaces):
+ # Get the xpath for this result
+ if is_attribute(tree, xpath, namespaces):
+ # Delete an attribute
+ parent = result.getparent()
+ # Pop this attribute match out of the parent
+ # node's 'attrib' dict by using this match's
+ # 'attrname' attribute for the key
+ parent.attrib.pop(result.attrname)
+ elif is_node(tree, xpath, namespaces):
+ # Delete an element
+ result.getparent().remove(result)
+ else:
+ raise Exception("Impossible error")
+ except Exception as e:
+ module.fail_json(msg="Couldn't delete xpath target: %s (%s)" % (xpath, e))
+ else:
+ finish(module, tree, xpath, namespaces, changed=True)
+
+
+def replace_children_of(children, match):
+ for element in list(match):
+ match.remove(element)
+ match.extend(children)
+
+
+def set_target_children_inner(module, tree, xpath, namespaces, children, in_type):
+ matches = tree.xpath(xpath, namespaces=namespaces)
+
+ # Create a list of our new children
+ children = children_to_nodes(module, children, in_type)
+ children_as_string = [etree.tostring(c) for c in children]
+
+ changed = False
+
+ # xpaths always return matches as a list, so....
+ for match in matches:
+ # Check if elements differ
+ if len(list(match)) == len(children):
+ for idx, element in enumerate(list(match)):
+ if etree.tostring(element) != children_as_string[idx]:
+ replace_children_of(children, match)
+ changed = True
+ break
+ else:
+ replace_children_of(children, match)
+ changed = True
+
+ return changed
+
+
+def set_target_children(module, tree, xpath, namespaces, children, in_type):
+ changed = set_target_children_inner(module, tree, xpath, namespaces, children, in_type)
+ # Write it out
+ finish(module, tree, xpath, namespaces, changed=changed)
+
+
+def add_target_children(module, tree, xpath, namespaces, children, in_type, insertbefore, insertafter):
+ if is_node(tree, xpath, namespaces):
+ new_kids = children_to_nodes(module, children, in_type)
+ if insertbefore or insertafter:
+ insert_target_children(tree, xpath, namespaces, new_kids, insertbefore, insertafter)
+ else:
+ for node in tree.xpath(xpath, namespaces=namespaces):
+ node.extend(new_kids)
+ finish(module, tree, xpath, namespaces, changed=True)
+ else:
+ finish(module, tree, xpath, namespaces)
+
+
+def insert_target_children(tree, xpath, namespaces, children, insertbefore, insertafter):
+ """
+ Insert the given children before or after the given xpath. If insertbefore is True, it is inserted before the
+ first xpath hit, with insertafter, it is inserted after the last xpath hit.
+ """
+ insert_target = tree.xpath(xpath, namespaces=namespaces)
+ loc_index = 0 if insertbefore else -1
+ index_in_parent = insert_target[loc_index].getparent().index(insert_target[loc_index])
+ parent = insert_target[0].getparent()
+ if insertafter:
+ index_in_parent += 1
+ for child in children:
+ parent.insert(index_in_parent, child)
+ index_in_parent += 1
+
+
+def _extract_xpstr(g):
+ return g[1:-1]
+
+
+def split_xpath_last(xpath):
+ """split an XPath of the form /foo/bar/baz into /foo/bar and baz"""
+ xpath = xpath.strip()
+ m = _RE_SPLITSIMPLELAST.match(xpath)
+ if m:
+ # requesting an element to exist
+ return (m.group(1), [(m.group(2), None)])
+ m = _RE_SPLITSIMPLELASTEQVALUE.match(xpath)
+ if m:
+ # requesting an element to exist with an inner text
+ return (m.group(1), [(m.group(2), _extract_xpstr(m.group(3)))])
+
+ m = _RE_SPLITSIMPLEATTRLAST.match(xpath)
+ if m:
+ # requesting an attribute to exist
+ return (m.group(1), [(m.group(2), None)])
+ m = _RE_SPLITSIMPLEATTRLASTEQVALUE.match(xpath)
+ if m:
+ # requesting an attribute to exist with a value
+ return (m.group(1), [(m.group(2), _extract_xpstr(m.group(3)))])
+
+ m = _RE_SPLITSUBLAST.match(xpath)
+ if m:
+ content = [x.strip() for x in m.group(3).split(" and ")]
+ return (m.group(1), [('/' + m.group(2), content)])
+
+ m = _RE_SPLITONLYEQVALUE.match(xpath)
+ if m:
+ # requesting a change of inner text
+ return (m.group(1), [("", _extract_xpstr(m.group(2)))])
+ return (xpath, [])
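+
+# For illustration, split_xpath_last() behaves roughly as follows (a sketch,
+# not an exhaustive contract):
+#   split_xpath_last("/business/phonenumber")
+#       -> ("/business", [("phonenumber", None)])
+#   split_xpath_last("/business/rating/@subjective='true'")
+#       -> ("/business/rating", [("@subjective", "true")])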
+
+
+def nsnameToClark(name, namespaces):
+ if ":" in name:
+ (nsname, rawname) = name.split(":")
+ # return "{{%s}}%s" % (namespaces[nsname], rawname)
+ return "{{{0}}}{1}".format(namespaces[nsname], rawname)
+
+ # no namespace name here
+ return name
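+
+# Illustration: nsnameToClark("z:baz", {"z": "http://z.test"}) returns
+# "{http://z.test}baz" (lxml's Clark notation); a name without a prefix,
+# such as "baz", is returned unchanged.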
+
+
+def check_or_make_target(module, tree, xpath, namespaces):
+ (inner_xpath, changes) = split_xpath_last(xpath)
+ if (inner_xpath == xpath) or (changes is None):
+ module.fail_json(msg="Can't process Xpath %s in order to spawn nodes! tree is %s" %
+ (xpath, etree.tostring(tree, pretty_print=True)))
+ return False
+
+ changed = False
+
+ if not is_node(tree, inner_xpath, namespaces):
+ changed = check_or_make_target(module, tree, inner_xpath, namespaces)
+
+ # we test again after calling check_or_make_target
+ if is_node(tree, inner_xpath, namespaces) and changes:
+ for (eoa, eoa_value) in changes:
+ if eoa and eoa[0] != '@' and eoa[0] != '/':
+ # implicitly creating an element
+ new_kids = children_to_nodes(module, [nsnameToClark(eoa, namespaces)], "yaml")
+ if eoa_value:
+ for nk in new_kids:
+ nk.text = eoa_value
+
+ for node in tree.xpath(inner_xpath, namespaces=namespaces):
+ node.extend(new_kids)
+ changed = True
+ # module.fail_json(msg="now tree=%s" % etree.tostring(tree, pretty_print=True))
+ elif eoa and eoa[0] == '/':
+ element = eoa[1:]
+ new_kids = children_to_nodes(module, [nsnameToClark(element, namespaces)], "yaml")
+ for node in tree.xpath(inner_xpath, namespaces=namespaces):
+ node.extend(new_kids)
+ for nk in new_kids:
+ for subexpr in eoa_value:
+ # module.fail_json(msg="element=%s subexpr=%s node=%s now tree=%s" %
+ # (element, subexpr, etree.tostring(node, pretty_print=True), etree.tostring(tree, pretty_print=True))
+ check_or_make_target(module, nk, "./" + subexpr, namespaces)
+ changed = True
+
+ # module.fail_json(msg="now tree=%s" % etree.tostring(tree, pretty_print=True))
+ elif eoa == "":
+ for node in tree.xpath(inner_xpath, namespaces=namespaces):
+ if (node.text != eoa_value):
+ node.text = eoa_value
+ changed = True
+
+ elif eoa and eoa[0] == '@':
+ attribute = nsnameToClark(eoa[1:], namespaces)
+
+ for element in tree.xpath(inner_xpath, namespaces=namespaces):
+ changing = (attribute not in element.attrib or element.attrib[attribute] != eoa_value)
+
+ if changing:
+ changed = changed or changing
+ if eoa_value is None:
+ value = ""
+ else:
+ value = eoa_value
+ element.attrib[attribute] = value
+
+ # module.fail_json(msg="arf %s changing=%s as curval=%s changed tree=%s" %
+ # (xpath, changing, etree.tostring(tree, changing, element[attribute], pretty_print=True)))
+
+ else:
+ module.fail_json(msg="unknown tree transformation=%s" % etree.tostring(tree, pretty_print=True))
+
+ return changed
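+
+# Illustration: for a tree containing only <business/>, calling
+# check_or_make_target() with xpath "/business/phonenumber" splits off the last
+# step via split_xpath_last() and appends an empty <phonenumber/> element under
+# <business>, returning True.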
+
+
+def ensure_xpath_exists(module, tree, xpath, namespaces):
+ changed = False
+
+ if not is_node(tree, xpath, namespaces):
+ changed = check_or_make_target(module, tree, xpath, namespaces)
+
+ finish(module, tree, xpath, namespaces, changed)
+
+
+def set_target_inner(module, tree, xpath, namespaces, attribute, value):
+ changed = False
+
+ try:
+ if not is_node(tree, xpath, namespaces):
+ changed = check_or_make_target(module, tree, xpath, namespaces)
+ except Exception as e:
+ missing_namespace = ""
+ # NOTE: This checks only the namespaces defined in root element!
+ # TODO: Implement a more robust check to check for child namespaces' existence
+ if tree.getroot().nsmap and ":" not in xpath:
+ missing_namespace = "XML document has namespace(s) defined, but no namespace prefix(es) used in xpath!\n"
+ module.fail_json(msg="%sXpath %s causes a failure: %s\n -- tree is %s" %
+ (missing_namespace, xpath, e, etree.tostring(tree, pretty_print=True)), exception=traceback.format_exc())
+
+ if not is_node(tree, xpath, namespaces):
+ module.fail_json(msg="Xpath %s does not reference a node! tree is %s" %
+ (xpath, etree.tostring(tree, pretty_print=True)))
+
+ for element in tree.xpath(xpath, namespaces=namespaces):
+ if not attribute:
+ changed = changed or (element.text != value)
+ if element.text != value:
+ element.text = value
+ else:
+ changed = changed or (element.get(attribute) != value)
+ if ":" in attribute:
+ attr_ns, attr_name = attribute.split(":")
+ # attribute = "{{%s}}%s" % (namespaces[attr_ns], attr_name)
+ attribute = "{{{0}}}{1}".format(namespaces[attr_ns], attr_name)
+ if element.get(attribute) != value:
+ element.set(attribute, value)
+
+ return changed
+
+
+def set_target(module, tree, xpath, namespaces, attribute, value):
+ changed = set_target_inner(module, tree, xpath, namespaces, attribute, value)
+ finish(module, tree, xpath, namespaces, changed)
+
+
+def get_element_text(module, tree, xpath, namespaces):
+ if not is_node(tree, xpath, namespaces):
+ module.fail_json(msg="Xpath %s does not reference a node!" % xpath)
+
+ elements = []
+ for element in tree.xpath(xpath, namespaces=namespaces):
+ elements.append({element.tag: element.text})
+
+ finish(module, tree, xpath, namespaces, changed=False, msg=len(elements), hitcount=len(elements), matches=elements)
+
+
+def get_element_attr(module, tree, xpath, namespaces):
+ if not is_node(tree, xpath, namespaces):
+ module.fail_json(msg="Xpath %s does not reference a node!" % xpath)
+
+ elements = []
+ for element in tree.xpath(xpath, namespaces=namespaces):
+ child = {}
+ for key in element.keys():
+ value = element.get(key)
+ child.update({key: value})
+ elements.append({element.tag: child})
+
+ finish(module, tree, xpath, namespaces, changed=False, msg=len(elements), hitcount=len(elements), matches=elements)
+
+
+def child_to_element(module, child, in_type):
+ if in_type == 'xml':
+ infile = BytesIO(to_bytes(child, errors='surrogate_or_strict'))
+
+ try:
+ parser = etree.XMLParser()
+ node = etree.parse(infile, parser)
+ return node.getroot()
+ except etree.XMLSyntaxError as e:
+ module.fail_json(msg="Error while parsing child element: %s" % e)
+ elif in_type == 'yaml':
+ if isinstance(child, string_types):
+ return etree.Element(child)
+ elif isinstance(child, MutableMapping):
+ if len(child) > 1:
+ module.fail_json(msg="Can only create children from hashes with one key")
+
+ (key, value) = next(iteritems(child))
+ if isinstance(value, MutableMapping):
+ children = value.pop('_', None)
+
+ node = etree.Element(key, value)
+
+ if children is not None:
+ if not isinstance(children, list):
+ module.fail_json(msg="Invalid children type: %s, must be list." % type(children))
+
+ subnodes = children_to_nodes(module, children)
+ node.extend(subnodes)
+ else:
+ node = etree.Element(key)
+ node.text = value
+ return node
+ else:
+ module.fail_json(msg="Invalid child type: %s. Children must be either strings or hashes." % type(child))
+ else:
+ module.fail_json(msg="Invalid child input type: %s. Type must be either xml or yaml." % in_type)
+
+
+def children_to_nodes(module=None, children=None, type='yaml'):
+ """turn a str/hash/list of str&hash into a list of elements"""
+ children = [] if children is None else children
+
+ return [child_to_element(module, child, type) for child in children]
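+
+# Illustration: children_to_nodes(module, ['beer', {'beer': 'Old Rasputin'}])
+# yields two elements, <beer/> and <beer>Old Rasputin</beer>.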
+
+
+def make_pretty(module, tree):
+ xml_string = etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print'])
+
+ result = dict(
+ changed=False,
+ )
+
+ if module.params['path']:
+ xml_file = module.params['path']
+ with open(xml_file, 'rb') as xml_content:
+ if xml_string != xml_content.read():
+ result['changed'] = True
+ if not module.check_mode:
+ if module.params['backup']:
+ result['backup_file'] = module.backup_local(module.params['path'])
+ tree.write(xml_file, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print'])
+
+ elif module.params['xmlstring']:
+ result['xmlstring'] = xml_string
+        # NOTE: Modifying a string is not considered a change!
+ if xml_string != module.params['xmlstring']:
+ result['changed'] = True
+
+ module.exit_json(**result)
+
+
+def finish(module, tree, xpath, namespaces, changed=False, msg='', hitcount=0, matches=tuple()):
+
+ result = dict(
+ actions=dict(
+ xpath=xpath,
+ namespaces=namespaces,
+ state=module.params['state']
+ ),
+ changed=has_changed(tree),
+ )
+
+ if module.params['count'] or hitcount:
+ result['count'] = hitcount
+
+ if module.params['print_match'] or matches:
+ result['matches'] = matches
+
+ if msg:
+ result['msg'] = msg
+
+ if result['changed']:
+ if module._diff:
+ result['diff'] = dict(
+ before=etree.tostring(orig_doc, xml_declaration=True, encoding='UTF-8', pretty_print=True),
+ after=etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=True),
+ )
+
+ if module.params['path'] and not module.check_mode:
+ if module.params['backup']:
+ result['backup_file'] = module.backup_local(module.params['path'])
+ tree.write(module.params['path'], xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print'])
+
+ if module.params['xmlstring']:
+ result['xmlstring'] = etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print'])
+
+ module.exit_json(**result)
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ path=dict(type='path', aliases=['dest', 'file']),
+ xmlstring=dict(type='str'),
+ xpath=dict(type='str'),
+ namespaces=dict(type='dict', default={}),
+ state=dict(type='str', default='present', choices=['absent', 'present'], aliases=['ensure']),
+ value=dict(type='raw'),
+ attribute=dict(type='raw'),
+ add_children=dict(type='list'),
+ set_children=dict(type='list'),
+ count=dict(type='bool', default=False),
+ print_match=dict(type='bool', default=False),
+ pretty_print=dict(type='bool', default=False),
+ content=dict(type='str', choices=['attribute', 'text']),
+ input_type=dict(type='str', default='yaml', choices=['xml', 'yaml']),
+ backup=dict(type='bool', default=False),
+ strip_cdata_tags=dict(type='bool', default=False),
+ insertbefore=dict(type='bool', default=False),
+ insertafter=dict(type='bool', default=False),
+ ),
+ supports_check_mode=True,
+ required_by=dict(
+ add_children=['xpath'],
+ # TODO: Reinstate this in Ansible v2.12 when we have deprecated the incorrect use below
+ # attribute=['value'],
+ content=['xpath'],
+ set_children=['xpath'],
+ value=['xpath'],
+ ),
+ required_if=[
+ ['count', True, ['xpath']],
+ ['print_match', True, ['xpath']],
+ ['insertbefore', True, ['xpath']],
+ ['insertafter', True, ['xpath']],
+ ],
+ required_one_of=[
+ ['path', 'xmlstring'],
+ ['add_children', 'content', 'count', 'pretty_print', 'print_match', 'set_children', 'value'],
+ ],
+ mutually_exclusive=[
+ ['add_children', 'content', 'count', 'print_match', 'set_children', 'value'],
+ ['path', 'xmlstring'],
+ ['insertbefore', 'insertafter'],
+ ],
+ )
+
+ xml_file = module.params['path']
+ xml_string = module.params['xmlstring']
+ xpath = module.params['xpath']
+ namespaces = module.params['namespaces']
+ state = module.params['state']
+ value = json_dict_bytes_to_unicode(module.params['value'])
+ attribute = module.params['attribute']
+ set_children = json_dict_bytes_to_unicode(module.params['set_children'])
+ add_children = json_dict_bytes_to_unicode(module.params['add_children'])
+ pretty_print = module.params['pretty_print']
+ content = module.params['content']
+ input_type = module.params['input_type']
+ print_match = module.params['print_match']
+ count = module.params['count']
+ backup = module.params['backup']
+ strip_cdata_tags = module.params['strip_cdata_tags']
+ insertbefore = module.params['insertbefore']
+ insertafter = module.params['insertafter']
+
+ # Check if we have lxml 2.3.0 or newer installed
+ if not HAS_LXML:
+ module.fail_json(msg=missing_required_lib("lxml"), exception=LXML_IMP_ERR)
+ elif LooseVersion('.'.join(to_native(f) for f in etree.LXML_VERSION)) < LooseVersion('2.3.0'):
+ module.fail_json(msg='The xml ansible module requires lxml 2.3.0 or newer installed on the managed machine')
+ elif LooseVersion('.'.join(to_native(f) for f in etree.LXML_VERSION)) < LooseVersion('3.0.0'):
+ module.warn('Using lxml version lower than 3.0.0 does not guarantee predictable element attribute order.')
+
+ # Report wrongly used attribute parameter when using content=attribute
+ # TODO: Remove this in Ansible v2.12 (and reinstate strict parameter test above) and remove the integration test example
+ if content == 'attribute' and attribute is not None:
+        module.deprecate("Parameter 'attribute=%s' is ignored when using 'content=attribute'; only 'xpath' is used. Please remove the entry." % attribute,
+                         '2.12', collection_name='ansible.builtin')
+
+ # Check if the file exists
+ if xml_string:
+ infile = BytesIO(to_bytes(xml_string, errors='surrogate_or_strict'))
+ elif os.path.isfile(xml_file):
+ infile = open(xml_file, 'rb')
+ else:
+ module.fail_json(msg="The target XML source '%s' does not exist." % xml_file)
+
+ # Parse and evaluate xpath expression
+ if xpath is not None:
+ try:
+ etree.XPath(xpath)
+ except etree.XPathSyntaxError as e:
+ module.fail_json(msg="Syntax error in xpath expression: %s (%s)" % (xpath, e))
+ except etree.XPathEvalError as e:
+ module.fail_json(msg="Evaluation error in xpath expression: %s (%s)" % (xpath, e))
+
+ # Try to parse in the target XML file
+ try:
+ parser = etree.XMLParser(remove_blank_text=pretty_print, strip_cdata=strip_cdata_tags)
+ doc = etree.parse(infile, parser)
+ except etree.XMLSyntaxError as e:
+ module.fail_json(msg="Error while parsing document: %s (%s)" % (xml_file or 'xml_string', e))
+
+ # Ensure we have the original copy to compare
+ global orig_doc
+ orig_doc = copy.deepcopy(doc)
+
+ if print_match:
+ do_print_match(module, doc, xpath, namespaces)
+
+ if count:
+ count_nodes(module, doc, xpath, namespaces)
+
+ if content == 'attribute':
+ get_element_attr(module, doc, xpath, namespaces)
+ elif content == 'text':
+ get_element_text(module, doc, xpath, namespaces)
+
+ # File exists:
+ if state == 'absent':
+ # - absent: delete xpath target
+ delete_xpath_target(module, doc, xpath, namespaces)
+
+ # - present: carry on
+
+ # children && value both set?: should have already aborted by now
+ # add_children && set_children both set?: should have already aborted by now
+
+ # set_children set?
+ if set_children:
+ set_target_children(module, doc, xpath, namespaces, set_children, input_type)
+
+ # add_children set?
+ if add_children:
+ add_target_children(module, doc, xpath, namespaces, add_children, input_type, insertbefore, insertafter)
+
+ # No?: Carry on
+
+ # Is the xpath target an attribute selector?
+ if value is not None:
+ set_target(module, doc, xpath, namespaces, attribute, value)
+
+ # If an xpath was provided, we need to do something with the data
+ if xpath is not None:
+ ensure_xpath_exists(module, doc, xpath, namespaces)
+
+ # Otherwise only reformat the xml data?
+ if pretty_print:
+ make_pretty(module, doc)
+
+ module.fail_json(msg="Don't know what to do")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test/support/integration/plugins/modules/zypper.py b/test/support/integration/plugins/modules/zypper.py
new file mode 100644
index 00000000..bfb31819
--- /dev/null
+++ b/test/support/integration/plugins/modules/zypper.py
@@ -0,0 +1,540 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2013, Patrick Callahan <pmc@patrickcallahan.com>
+# based on
+# openbsd_pkg
+# (c) 2013
+# Patrik Lundin <patrik.lundin.swe@gmail.com>
+#
+# yum
+# (c) 2012, Red Hat, Inc
+# Written by Seth Vidal <skvidal at fedoraproject.org>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: zypper
+author:
+ - "Patrick Callahan (@dirtyharrycallahan)"
+ - "Alexander Gubin (@alxgu)"
+ - "Thomas O'Donnell (@andytom)"
+ - "Robin Roth (@robinro)"
+ - "Andrii Radyk (@AnderEnder)"
+version_added: "1.2"
+short_description: Manage packages on SUSE and openSUSE
+description:
+ - Manage packages on SUSE and openSUSE using the zypper and rpm tools.
+options:
+ name:
+ description:
+ - Package name C(name) or package specifier or a list of either.
+ - Can include a version like C(name=1.0), C(name>3.4) or C(name<=2.7). If a version is given, C(oldpackage) is implied and zypper is allowed to
+ update the package within the version range given.
+ - You can also pass a url or a local path to a rpm file.
+ - When using state=latest, this can be '*', which updates all installed packages.
+ required: true
+ aliases: [ 'pkg' ]
+ state:
+ description:
+ - C(present) will make sure the package is installed.
+ C(latest) will make sure the latest version of the package is installed.
+ C(absent) will make sure the specified package is not installed.
+ C(dist-upgrade) will make sure the latest version of all installed packages from all enabled repositories is installed.
+ - When using C(dist-upgrade), I(name) should be C('*').
+ required: false
+ choices: [ present, latest, absent, dist-upgrade ]
+ default: "present"
+ type:
+ description:
+ - The type of package to be operated on.
+ required: false
+ choices: [ package, patch, pattern, product, srcpackage, application ]
+ default: "package"
+ version_added: "2.0"
+ extra_args_precommand:
+ version_added: "2.6"
+ required: false
+ description:
+ - Add additional global target options to C(zypper).
+ - Options should be supplied in a single line as if given in the command line.
+ disable_gpg_check:
+ description:
+        - Whether to disable the GPG signature checking of the package
+          being installed. Has an effect only if state is
+          I(present) or I(latest).
+ required: false
+ default: "no"
+ type: bool
+ disable_recommends:
+ version_added: "1.8"
+ description:
+        - Corresponds to the C(--no-recommends) option for I(zypper). The default (C(yes)) passes
+          C(--no-recommends) and thus modifies zypper's default behavior; C(no) does install
+          recommended packages.
+ required: false
+ default: "yes"
+ type: bool
+ force:
+ version_added: "2.2"
+ description:
+        - Adds C(--force) option to I(zypper). Allows downgrading packages and changing vendor or architecture.
+ required: false
+ default: "no"
+ type: bool
+ force_resolution:
+ version_added: "2.10"
+ description:
+        - Adds C(--force-resolution) option to I(zypper). Allows (un)installing packages with conflicting requirements (the resolver will choose a solution).
+ required: false
+ default: "no"
+ type: bool
+ update_cache:
+ version_added: "2.2"
+ description:
+ - Run the equivalent of C(zypper refresh) before the operation. Disabled in check mode.
+ required: false
+ default: "no"
+ type: bool
+ aliases: [ "refresh" ]
+ oldpackage:
+ version_added: "2.2"
+ description:
+        - Adds C(--oldpackage) option to I(zypper). Allows downgrading packages with fewer
+          side effects than C(force). This is implied as soon as a version is specified as
+          part of the package name.
+ required: false
+ default: "no"
+ type: bool
+ extra_args:
+ version_added: "2.4"
+ required: false
+ description:
+ - Add additional options to C(zypper) command.
+ - Options should be supplied in a single line as if given in the command line.
+notes:
+  - When used with a `loop:`, each package is processed individually; it is much more
+    efficient to pass the list directly to the `name` option.
+# informational: requirements for nodes
+requirements:
+ - "zypper >= 1.0 # included in openSUSE >= 11.1 or SUSE Linux Enterprise Server/Desktop >= 11.0"
+ - python-xml
+ - rpm
+'''
+
+EXAMPLES = '''
+# Install "nmap"
+- zypper:
+ name: nmap
+ state: present
+
+# Install apache2 with recommended packages
+- zypper:
+ name: apache2
+ state: present
+ disable_recommends: no
+
+# Apply a given patch
+- zypper:
+ name: openSUSE-2016-128
+ state: present
+ type: patch
+
+# Remove the "nmap" package
+- zypper:
+ name: nmap
+ state: absent
+
+# Install the nginx rpm from a remote repo
+- zypper:
+ name: 'http://nginx.org/packages/sles/12/x86_64/RPMS/nginx-1.8.0-1.sles12.ngx.x86_64.rpm'
+ state: present
+
+# Install local rpm file
+- zypper:
+ name: /tmp/fancy-software.rpm
+ state: present
+
+# Update all packages
+- zypper:
+ name: '*'
+ state: latest
+
+# Apply all available patches
+- zypper:
+ name: '*'
+ state: latest
+ type: patch
+
+# Perform a dist-upgrade with additional arguments
+- zypper:
+ name: '*'
+ state: dist-upgrade
+ extra_args: '--no-allow-vendor-change --allow-arch-change'
+
+# Refresh repositories and update package "openssl"
+- zypper:
+ name: openssl
+ state: present
+ update_cache: yes
+
+# Install specific version (possible comparisons: <, >, <=, >=, =)
+- zypper:
+ name: 'docker>=1.10'
+ state: present
+
+# Wait 20 seconds to acquire the lock before failing
+- zypper:
+ name: mosh
+ state: present
+ environment:
+ ZYPP_LOCK_TIMEOUT: 20
+'''
+
+import xml
+import re
+from xml.dom.minidom import parseString as parseXML
+from ansible.module_utils.six import iteritems
+from ansible.module_utils._text import to_native
+
+# import module snippets
+from ansible.module_utils.basic import AnsibleModule
+
+
+class Package:
+ def __init__(self, name, prefix, version):
+ self.name = name
+ self.prefix = prefix
+ self.version = version
+ self.shouldinstall = (prefix == '+')
+
+ def __str__(self):
+ return self.prefix + self.name + self.version
+
+
+def split_name_version(name):
+    """splits off the package name and the desired version
+
+ example formats:
+ - docker>=1.10
+ - apache=2.4
+
+ Allowed version specifiers: <, >, <=, >=, =
+ Allowed version format: [0-9.-]*
+
+ Also allows a prefix indicating remove "-", "~" or install "+"
+ """
+
+ prefix = ''
+ if name[0] in ['-', '~', '+']:
+ prefix = name[0]
+ name = name[1:]
+ if prefix == '~':
+ prefix = '-'
+
+ version_check = re.compile('^(.*?)((?:<|>|<=|>=|=)[0-9.-]*)?$')
+ try:
+ reres = version_check.match(name)
+ name, version = reres.groups()
+ if version is None:
+ version = ''
+ return prefix, name, version
+ except Exception:
+ return prefix, name, ''
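+
+# Illustration of the split (a sketch):
+#   split_name_version("docker>=1.10") -> ('', 'docker', '>=1.10')
+#   split_name_version("-exim")        -> ('-', 'exim', '')
+#   split_name_version("~nmap")        -> ('-', 'nmap', '')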
+
+
+def get_want_state(names, remove=False):
+ packages = []
+ urls = []
+ for name in names:
+ if '://' in name or name.endswith('.rpm'):
+ urls.append(name)
+ else:
+ prefix, pname, version = split_name_version(name)
+ if prefix not in ['-', '+']:
+ if remove:
+ prefix = '-'
+ else:
+ prefix = '+'
+ packages.append(Package(pname, prefix, version))
+ return packages, urls
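+
+# Illustration: get_want_state(['nmap', 'docker>=1.10']) yields two Package
+# objects that stringify as '+nmap' and '+docker>=1.10' (and no URLs); with
+# remove=True the default prefix becomes '-'.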
+
+
+def get_installed_state(m, packages):
+ "get installed state of packages"
+
+ cmd = get_cmd(m, 'search')
+ cmd.extend(['--match-exact', '--details', '--installed-only'])
+ cmd.extend([p.name for p in packages])
+ return parse_zypper_xml(m, cmd, fail_not_found=False)[0]
+
+
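+# The package mapping returned by parse_zypper_xml() has roughly this shape
+# (illustrative values):
+#   {'nmap': {'version': '7.70-1.1', 'oldversion': '', 'installed': False,
+#             'group': 'to-install'}}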
+def parse_zypper_xml(m, cmd, fail_not_found=True, packages=None):
+ rc, stdout, stderr = m.run_command(cmd, check_rc=False)
+
+ try:
+ dom = parseXML(stdout)
+ except xml.parsers.expat.ExpatError as exc:
+ m.fail_json(msg="Failed to parse zypper xml output: %s" % to_native(exc),
+ rc=rc, stdout=stdout, stderr=stderr, cmd=cmd)
+
+ if rc == 104:
+ # exit code 104 is ZYPPER_EXIT_INF_CAP_NOT_FOUND (no packages found)
+ if fail_not_found:
+ errmsg = dom.getElementsByTagName('message')[-1].childNodes[0].data
+ m.fail_json(msg=errmsg, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd)
+ else:
+ return {}, rc, stdout, stderr
+ elif rc in [0, 106, 103]:
+ # zypper exit codes
+ # 0: success
+ # 106: signature verification failed
+ # 103: zypper was upgraded, run same command again
+        if packages is None:
+            firstrun = True
+            packages = {}
+        else:
+            firstrun = False
+ solvable_list = dom.getElementsByTagName('solvable')
+ for solvable in solvable_list:
+ name = solvable.getAttribute('name')
+ packages[name] = {}
+ packages[name]['version'] = solvable.getAttribute('edition')
+ packages[name]['oldversion'] = solvable.getAttribute('edition-old')
+ status = solvable.getAttribute('status')
+ packages[name]['installed'] = status == "installed"
+ packages[name]['group'] = solvable.parentNode.nodeName
+ if rc == 103 and firstrun:
+ # if this was the first run and it failed with 103
+ # run zypper again with the same command to complete update
+ return parse_zypper_xml(m, cmd, fail_not_found=fail_not_found, packages=packages)
+
+ return packages, rc, stdout, stderr
+ m.fail_json(msg='Zypper run command failed with return code %s.' % rc, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd)
+
+
+def get_cmd(m, subcommand):
+ "puts together the basic zypper command arguments with those passed to the module"
+ is_install = subcommand in ['install', 'update', 'patch', 'dist-upgrade']
+ is_refresh = subcommand == 'refresh'
+ cmd = ['/usr/bin/zypper', '--quiet', '--non-interactive', '--xmlout']
+ if m.params['extra_args_precommand']:
+ args_list = m.params['extra_args_precommand'].split()
+ cmd.extend(args_list)
+ # add global options before zypper command
+ if (is_install or is_refresh) and m.params['disable_gpg_check']:
+ cmd.append('--no-gpg-checks')
+
+ if subcommand == 'search':
+ cmd.append('--disable-repositories')
+
+ cmd.append(subcommand)
+ if subcommand not in ['patch', 'dist-upgrade'] and not is_refresh:
+ cmd.extend(['--type', m.params['type']])
+ if m.check_mode and subcommand != 'search':
+ cmd.append('--dry-run')
+ if is_install:
+ cmd.append('--auto-agree-with-licenses')
+ if m.params['disable_recommends']:
+ cmd.append('--no-recommends')
+ if m.params['force']:
+ cmd.append('--force')
+ if m.params['force_resolution']:
+ cmd.append('--force-resolution')
+ if m.params['oldpackage']:
+ cmd.append('--oldpackage')
+ if m.params['extra_args']:
+ args_list = m.params['extra_args'].split(' ')
+ cmd.extend(args_list)
+
+ return cmd
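+
+# Illustration: with the module defaults, get_cmd(m, 'install') assembles
+# ['/usr/bin/zypper', '--quiet', '--non-interactive', '--xmlout', 'install',
+#  '--type', 'package', '--auto-agree-with-licenses', '--no-recommends'].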
+
+
+def set_diff(m, retvals, result):
+ # TODO: if there is only one package, set before/after to version numbers
+ packages = {'installed': [], 'removed': [], 'upgraded': []}
+ if result:
+ for p in result:
+ group = result[p]['group']
+ if group == 'to-upgrade':
+ versions = ' (' + result[p]['oldversion'] + ' => ' + result[p]['version'] + ')'
+ packages['upgraded'].append(p + versions)
+ elif group == 'to-install':
+ packages['installed'].append(p)
+ elif group == 'to-remove':
+ packages['removed'].append(p)
+
+ output = ''
+ for state in packages:
+ if packages[state]:
+ output += state + ': ' + ', '.join(packages[state]) + '\n'
+ if 'diff' not in retvals:
+ retvals['diff'] = {}
+ if 'prepared' not in retvals['diff']:
+ retvals['diff']['prepared'] = output
+ else:
+ retvals['diff']['prepared'] += '\n' + output
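+
+# The 'prepared' diff text built above looks roughly like (illustrative):
+#   installed: nmap
+#   upgraded: openssl (1.0.2j => 1.0.2k)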
+
+
+def package_present(m, name, want_latest):
+    "install and update (if want_latest) the packages in name; names prefixed with '-' are removed, the rest are installed"
+ retvals = {'rc': 0, 'stdout': '', 'stderr': ''}
+ packages, urls = get_want_state(name)
+
+ # add oldpackage flag when a version is given to allow downgrades
+ if any(p.version for p in packages):
+ m.params['oldpackage'] = True
+
+ if not want_latest:
+ # for state=present: filter out already installed packages
+ # if a version is given leave the package in to let zypper handle the version
+ # resolution
+ packageswithoutversion = [p for p in packages if not p.version]
+ prerun_state = get_installed_state(m, packageswithoutversion)
+ # generate lists of packages to install or remove
+ packages = [p for p in packages if p.shouldinstall != (p.name in prerun_state)]
+
+ if not packages and not urls:
+ # nothing to install/remove and nothing to update
+ return None, retvals
+
+ # zypper install also updates packages
+ cmd = get_cmd(m, 'install')
+ cmd.append('--')
+ cmd.extend(urls)
+ # pass packages to zypper
+ # allow for + or - prefixes in install/remove lists
+ # also add version specifier if given
+ # do this in one zypper run to allow for dependency-resolution
+    # for example, "-exim postfix" removes exim and installs postfix in a single
+    # run, without removing packages that depend on a mail server
+ cmd.extend([str(p) for p in packages])
+
+ retvals['cmd'] = cmd
+ result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd)
+
+ return result, retvals
+
+
+def package_update_all(m):
+ "run update or patch on all available packages"
+
+ retvals = {'rc': 0, 'stdout': '', 'stderr': ''}
+ if m.params['type'] == 'patch':
+ cmdname = 'patch'
+ elif m.params['state'] == 'dist-upgrade':
+ cmdname = 'dist-upgrade'
+ else:
+ cmdname = 'update'
+
+ cmd = get_cmd(m, cmdname)
+ retvals['cmd'] = cmd
+ result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd)
+ return result, retvals
+
+
+def package_absent(m, name):
+ "remove the packages in name"
+ retvals = {'rc': 0, 'stdout': '', 'stderr': ''}
+ # Get package state
+ packages, urls = get_want_state(name, remove=True)
+ if any(p.prefix == '+' for p in packages):
+ m.fail_json(msg="Can not combine '+' prefix with state=remove/absent.")
+ if urls:
+ m.fail_json(msg="Can not remove via URL.")
+ if m.params['type'] == 'patch':
+ m.fail_json(msg="Can not remove patches.")
+ prerun_state = get_installed_state(m, packages)
+ packages = [p for p in packages if p.name in prerun_state]
+
+ if not packages:
+ return None, retvals
+
+ cmd = get_cmd(m, 'remove')
+ cmd.extend([p.name + p.version for p in packages])
+
+ retvals['cmd'] = cmd
+ result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd)
+ return result, retvals
+
+
+def repo_refresh(m):
+ "update the repositories"
+ retvals = {'rc': 0, 'stdout': '', 'stderr': ''}
+
+ cmd = get_cmd(m, 'refresh')
+
+ retvals['cmd'] = cmd
+ result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd)
+
+ return retvals
+
+# ===========================================
+# Main control flow
+
+
+def main():
+ module = AnsibleModule(
+ argument_spec=dict(
+ name=dict(required=True, aliases=['pkg'], type='list'),
+ state=dict(required=False, default='present', choices=['absent', 'installed', 'latest', 'present', 'removed', 'dist-upgrade']),
+ type=dict(required=False, default='package', choices=['package', 'patch', 'pattern', 'product', 'srcpackage', 'application']),
+ extra_args_precommand=dict(required=False, default=None),
+ disable_gpg_check=dict(required=False, default='no', type='bool'),
+ disable_recommends=dict(required=False, default='yes', type='bool'),
+ force=dict(required=False, default='no', type='bool'),
+ force_resolution=dict(required=False, default='no', type='bool'),
+ update_cache=dict(required=False, aliases=['refresh'], default='no', type='bool'),
+ oldpackage=dict(required=False, default='no', type='bool'),
+ extra_args=dict(required=False, default=None),
+ ),
+ supports_check_mode=True
+ )
+
+ name = module.params['name']
+ state = module.params['state']
+ update_cache = module.params['update_cache']
+
+ # remove empty strings from package list
+ name = list(filter(None, name))
+
+ # Refresh repositories
+ if update_cache and not module.check_mode:
+ retvals = repo_refresh(module)
+
+ if retvals['rc'] != 0:
+ module.fail_json(msg="Zypper refresh run failed.", **retvals)
+
+ # Perform requested action
+ if name == ['*'] and state in ['latest', 'dist-upgrade']:
+ packages_changed, retvals = package_update_all(module)
+ elif name != ['*'] and state == 'dist-upgrade':
+ module.fail_json(msg="Can not dist-upgrade specific packages.")
+ else:
+ if state in ['absent', 'removed']:
+ packages_changed, retvals = package_absent(module, name)
+ elif state in ['installed', 'present', 'latest']:
+ packages_changed, retvals = package_present(module, name, state == 'latest')
+
+ retvals['changed'] = retvals['rc'] == 0 and bool(packages_changed)
+
+ if module._diff:
+ set_diff(module, retvals, packages_changed)
+
+ if retvals['rc'] != 0:
+ module.fail_json(msg="Zypper run failed.", **retvals)
+
+ if not retvals['changed']:
+ del retvals['stdout']
+ del retvals['stderr']
+
+ module.exit_json(name=name, state=state, update_cache=update_cache, **retvals)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py
new file mode 100644
index 00000000..089b339f
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py
@@ -0,0 +1,40 @@
+#
+# Copyright 2018 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+from ansible_collections.ansible.netcommon.plugins.action.network import (
+ ActionModule as ActionNetworkModule,
+)
+
+
+class ActionModule(ActionNetworkModule):
+ def run(self, tmp=None, task_vars=None):
+ del tmp # tmp no longer has any effect
+
+ self._config_module = True
+ if self._play_context.connection.split(".")[-1] != "network_cli":
+ return {
+ "failed": True,
+ "msg": "Connection type %s is not valid for cli_config module"
+ % self._play_context.connection,
+ }
+
+ return super(ActionModule, self).run(task_vars=task_vars)
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py
new file mode 100644
index 00000000..542dcfef
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py
@@ -0,0 +1,90 @@
+# Copyright: (c) 2015, Ansible Inc,
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import copy
+
+from ansible.errors import AnsibleError
+from ansible.plugins.action import ActionBase
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class ActionModule(ActionBase):
+ def run(self, tmp=None, task_vars=None):
+ del tmp # tmp no longer has any effect
+
+ result = {}
+ play_context = copy.deepcopy(self._play_context)
+ play_context.network_os = self._get_network_os(task_vars)
+ new_task = self._task.copy()
+
+ module = self._get_implementation_module(
+ play_context.network_os, self._task.action
+ )
+ if not module:
+ if self._task.args["fail_on_missing_module"]:
+ result["failed"] = True
+ else:
+ result["failed"] = False
+
+ result["msg"] = (
+ "Could not find implementation module %s for %s"
+ % (self._task.action, play_context.network_os)
+ )
+ return result
+
+ new_task.action = module
+
+ action = self._shared_loader_obj.action_loader.get(
+ play_context.network_os,
+ task=new_task,
+ connection=self._connection,
+ play_context=play_context,
+ loader=self._loader,
+ templar=self._templar,
+ shared_loader_obj=self._shared_loader_obj,
+ )
+ display.vvvv("Running implementation module %s" % module)
+ return action.run(task_vars=task_vars)
+
+ def _get_network_os(self, task_vars):
+ if "network_os" in self._task.args and self._task.args["network_os"]:
+ display.vvvv("Getting network OS from task argument")
+ network_os = self._task.args["network_os"]
+ elif self._play_context.network_os:
+ display.vvvv("Getting network OS from inventory")
+ network_os = self._play_context.network_os
+ elif (
+ "network_os" in task_vars.get("ansible_facts", {})
+ and task_vars["ansible_facts"]["network_os"]
+ ):
+ display.vvvv("Getting network OS from fact")
+ network_os = task_vars["ansible_facts"]["network_os"]
+ else:
+ raise AnsibleError(
+ "ansible_network_os must be specified on this host to use platform agnostic modules"
+ )
+
+ return network_os
+
+ def _get_implementation_module(self, network_os, platform_agnostic_module):
+ module_name = (
+ network_os.split(".")[-1]
+ + "_"
+ + platform_agnostic_module.partition("_")[2]
+ )
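+        # Illustration (hypothetical platform): with network_os "cisco.ios.ios"
+        # and task action "net_banner", module_name is "ios_banner" and the
+        # fully-qualified implementation module below becomes
+        # "cisco.ios.ios_banner".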
+ if "." in network_os:
+ fqcn_module = ".".join(network_os.split(".")[0:-1])
+ implementation_module = fqcn_module + "." + module_name
+ else:
+ implementation_module = module_name
+
+ if implementation_module not in self._shared_loader_obj.module_loader:
+ implementation_module = None
+
+ return implementation_module
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py
new file mode 100644
index 00000000..40205a46
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py
@@ -0,0 +1,199 @@
+# (c) 2018, Ansible Inc,
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import os
+import re
+import uuid
+import hashlib
+
+from ansible.errors import AnsibleError
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.module_utils.connection import Connection, ConnectionError
+from ansible.plugins.action import ActionBase
+from ansible.module_utils.six.moves.urllib.parse import urlsplit
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class ActionModule(ActionBase):
+ def run(self, tmp=None, task_vars=None):
+ socket_path = None
+ self._get_network_os(task_vars)
+ persistent_connection = self._play_context.connection.split(".")[-1]
+
+ result = super(ActionModule, self).run(task_vars=task_vars)
+
+ if persistent_connection != "network_cli":
+ # It is supported only with network_cli
+ result["failed"] = True
+ result["msg"] = (
+ "connection type %s is not valid for net_get module,"
+ " please use fully qualified name of network_cli connection type"
+ % self._play_context.connection
+ )
+ return result
+
+ try:
+ src = self._task.args["src"]
+ except KeyError as exc:
+ return {
+ "failed": True,
+ "msg": "missing required argument: %s" % exc,
+ }
+
+ # Get destination file if specified
+ dest = self._task.args.get("dest")
+
+ if dest is None:
+ dest = self._get_default_dest(src)
+ else:
+ dest = self._handle_dest_path(dest)
+
+ # Get proto
+ proto = self._task.args.get("protocol")
+ if proto is None:
+ proto = "scp"
+
+ if socket_path is None:
+ socket_path = self._connection.socket_path
+
+ conn = Connection(socket_path)
+ sock_timeout = conn.get_option("persistent_command_timeout")
+
+ try:
+ changed = self._handle_existing_file(
+ conn, src, dest, proto, sock_timeout
+ )
+ if changed is False:
+ result["changed"] = changed
+ result["destination"] = dest
+ return result
+ except Exception as exc:
+ result["msg"] = (
+ "Warning: %s idempotency check failed. Check dest" % exc
+ )
+
+ try:
+ conn.get_file(
+ source=src, destination=dest, proto=proto, timeout=sock_timeout
+ )
+ except Exception as exc:
+ result["failed"] = True
+ result["msg"] = "Exception received: %s" % exc
+
+ result["changed"] = changed
+ result["destination"] = dest
+ return result
+
+ def _handle_dest_path(self, dest):
+ working_path = self._get_working_path()
+
+        if os.path.isabs(dest) or urlsplit(dest).scheme:
+ dst = dest
+ else:
+ dst = self._loader.path_dwim_relative(working_path, "", dest)
+
+ return dst
+
+ def _get_src_filename_from_path(self, src_path):
+ filename_list = re.split("/|:", src_path)
+ return filename_list[-1]
+
+ def _get_default_dest(self, src_path):
+ dest_path = self._get_working_path()
+ src_fname = self._get_src_filename_from_path(src_path)
+ filename = "%s/%s" % (dest_path, src_fname)
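+        # Illustration (hypothetical device path): for src "flash:/running-config.txt",
+        # src_fname is "running-config.txt" and the default destination becomes
+        # "<working_path>/running-config.txt".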
+ return filename
+
+ def _handle_existing_file(self, conn, source, dest, proto, timeout):
+ """
+ Determines whether the source and destination file match.
+
+ :return: False if source and dest both exist and have matching sha1 sums, True otherwise.
+ """
+ if not os.path.exists(dest):
+ return True
+
+ cwd = self._loader.get_basedir()
+ filename = str(uuid.uuid4())
+ tmp_dest_file = os.path.join(cwd, filename)
+ try:
+ conn.get_file(
+ source=source,
+ destination=tmp_dest_file,
+ proto=proto,
+ timeout=timeout,
+ )
+ except ConnectionError as exc:
+ error = to_text(exc)
+ if error.endswith("No such file or directory"):
+ if os.path.exists(tmp_dest_file):
+ os.remove(tmp_dest_file)
+ return True
+
+ try:
+ with open(tmp_dest_file, "r") as f:
+ new_content = f.read()
+ with open(dest, "r") as f:
+ old_content = f.read()
+ except (IOError, OSError):
+ os.remove(tmp_dest_file)
+ raise
+
+ sha1 = hashlib.sha1()
+ old_content_b = to_bytes(old_content, errors="surrogate_or_strict")
+ sha1.update(old_content_b)
+ checksum_old = sha1.digest()
+
+ sha1 = hashlib.sha1()
+ new_content_b = to_bytes(new_content, errors="surrogate_or_strict")
+ sha1.update(new_content_b)
+ checksum_new = sha1.digest()
+ os.remove(tmp_dest_file)
+ if checksum_old == checksum_new:
+ return False
+ return True
+
+ def _get_working_path(self):
+ cwd = self._loader.get_basedir()
+ if self._task._role is not None:
+ cwd = self._task._role._role_path
+ return cwd
+
+ def _get_network_os(self, task_vars):
+ if "network_os" in self._task.args and self._task.args["network_os"]:
+ display.vvvv("Getting network OS from task argument")
+ network_os = self._task.args["network_os"]
+ elif self._play_context.network_os:
+ display.vvvv("Getting network OS from inventory")
+ network_os = self._play_context.network_os
+ elif (
+ "network_os" in task_vars.get("ansible_facts", {})
+ and task_vars["ansible_facts"]["network_os"]
+ ):
+ display.vvvv("Getting network OS from fact")
+ network_os = task_vars["ansible_facts"]["network_os"]
+ else:
+ raise AnsibleError(
+ "ansible_network_os must be specified on this host"
+ )
+
+ return network_os
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py
new file mode 100644
index 00000000..955329d4
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py
@@ -0,0 +1,235 @@
+# (c) 2018, Ansible Inc,
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import os
+import uuid
+import hashlib
+
+from ansible.errors import AnsibleError
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.module_utils.connection import Connection, ConnectionError
+from ansible.plugins.action import ActionBase
+from ansible.module_utils.six.moves.urllib.parse import urlsplit
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class ActionModule(ActionBase):
+ def run(self, tmp=None, task_vars=None):
+ socket_path = None
+ network_os = self._get_network_os(task_vars).split(".")[-1]
+ persistent_connection = self._play_context.connection.split(".")[-1]
+
+ result = super(ActionModule, self).run(task_vars=task_vars)
+
+ if persistent_connection != "network_cli":
+ # It is supported only with network_cli
+ result["failed"] = True
+ result["msg"] = (
+ "connection type %s is not valid for net_put module,"
+ " please use fully qualified name of network_cli connection type"
+ % self._play_context.connection
+ )
+ return result
+
+ try:
+ src = self._task.args["src"]
+ except KeyError as exc:
+ return {
+ "failed": True,
+ "msg": "missing required argument: %s" % exc,
+ }
+
+ src_file_path_name = src
+
+ # Get destination file if specified
+ dest = self._task.args.get("dest")
+
+ # Get proto
+ proto = self._task.args.get("protocol")
+ if proto is None:
+ proto = "scp"
+
+ # Get mode if set
+ mode = self._task.args.get("mode")
+ if mode is None:
+ mode = "binary"
+
+ if mode == "text":
+ try:
+ self._handle_template(convert_data=False)
+ except ValueError as exc:
+ return dict(failed=True, msg=to_text(exc))
+
+            # Now that src holds the resolved file contents, write it to disk in the current directory for scp
+ src = self._task.args.get("src")
+ filename = str(uuid.uuid4())
+ cwd = self._loader.get_basedir()
+ output_file = os.path.join(cwd, filename)
+ try:
+ with open(output_file, "wb") as f:
+ f.write(to_bytes(src, encoding="utf-8"))
+ except Exception:
+ os.remove(output_file)
+ raise
+ else:
+ try:
+ output_file = self._get_binary_src_file(src)
+ except ValueError as exc:
+ return dict(failed=True, msg=to_text(exc))
+
+ if socket_path is None:
+ socket_path = self._connection.socket_path
+
+ conn = Connection(socket_path)
+ sock_timeout = conn.get_option("persistent_command_timeout")
+
+ if dest is None:
+ dest = src_file_path_name
+
+ try:
+ changed = self._handle_existing_file(
+ conn, output_file, dest, proto, sock_timeout
+ )
+ if changed is False:
+ result["changed"] = changed
+ result["destination"] = dest
+ return result
+ except Exception as exc:
+ # If the idempotency check itself fails, assume a copy is needed
+ # so `changed` is always bound below.
+ changed = True
+ result["msg"] = (
+ "Warning: idempotency check failed (%s). Check dest" % exc
+ )
+
+ try:
+ conn.copy_file(
+ source=output_file,
+ destination=dest,
+ proto=proto,
+ timeout=sock_timeout,
+ )
+ except Exception as exc:
+ if to_text(exc) == "No response from server":
+ if network_os == "iosxr":
+ # IOSXR sometimes closes socket prematurely after completion
+ # of file transfer
+ result[
+ "msg"
+ ] = "Warning: iosxr scp server closed the connection prematurely. Please check dest"
+ else:
+ result["failed"] = True
+ result["msg"] = "Exception received: %s" % exc
+
+ if mode == "text":
+ # Clean up the tmp file expanded with ansible vars
+ os.remove(output_file)
+
+ result["changed"] = changed
+ result["destination"] = dest
+ return result
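+
+ # A hedged usage sketch: by the time run() executes, self._task.args would
+ # typically carry something like (hypothetical values)
+ #   {"src": "running_cfg.txt", "dest": "bootflash:/running_cfg.txt",
+ #    "protocol": "scp", "mode": "binary"}
+ # which the logic above resolves into an scp/sftp copy over the persistent
+ # network_cli connection, with the sha1-based idempotency check below
+ # deciding whether a transfer is needed at all.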
+
+ def _handle_existing_file(self, conn, source, dest, proto, timeout):
+ """
+ Determines whether the source and destination file match.
+
+ :return: False if source and dest both exist and have matching sha1 sums, True otherwise.
+ """
+ cwd = self._loader.get_basedir()
+ filename = str(uuid.uuid4())
+ tmp_source_file = os.path.join(cwd, filename)
+ try:
+ conn.get_file(
+ source=dest,
+ destination=tmp_source_file,
+ proto=proto,
+ timeout=timeout,
+ )
+ except ConnectionError as exc:
+ error = to_text(exc)
+ if error.endswith("No such file or directory"):
+ if os.path.exists(tmp_source_file):
+ os.remove(tmp_source_file)
+ return True
+
+ try:
+ with open(source, "r") as f:
+ new_content = f.read()
+ with open(tmp_source_file, "r") as f:
+ old_content = f.read()
+ except (IOError, OSError):
+ os.remove(tmp_source_file)
+ raise
+
+ sha1 = hashlib.sha1()
+ old_content_b = to_bytes(old_content, errors="surrogate_or_strict")
+ sha1.update(old_content_b)
+ checksum_old = sha1.digest()
+
+ sha1 = hashlib.sha1()
+ new_content_b = to_bytes(new_content, errors="surrogate_or_strict")
+ sha1.update(new_content_b)
+ checksum_new = sha1.digest()
+ os.remove(tmp_source_file)
+ if checksum_old == checksum_new:
+ return False
+ return True
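+
+ # A minimal standalone sketch of the comparison above (hypothetical helper,
+ # not part of this plugin; the hashlib usage mirrors the method):
+ #
+ #   import hashlib
+ #
+ #   def same_content(path_a, path_b):
+ #       with open(path_a, "rb") as a, open(path_b, "rb") as b:
+ #           return hashlib.sha1(a.read()).digest() == hashlib.sha1(b.read()).digest()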
+
+ def _get_binary_src_file(self, src):
+ working_path = self._get_working_path()
+
+ if os.path.isabs(src) or urlsplit("src").scheme:
+ source = src
+ else:
+ source = self._loader.path_dwim_relative(
+ working_path, "templates", src
+ )
+ if not source:
+ source = self._loader.path_dwim_relative(working_path, src)
+
+ if not os.path.exists(source):
+ raise ValueError("path specified in src not found")
+
+ return source
+
+ def _get_working_path(self):
+ cwd = self._loader.get_basedir()
+ if self._task._role is not None:
+ cwd = self._task._role._role_path
+ return cwd
+
+ def _get_network_os(self, task_vars):
+ if "network_os" in self._task.args and self._task.args["network_os"]:
+ display.vvvv("Getting network OS from task argument")
+ network_os = self._task.args["network_os"]
+ elif self._play_context.network_os:
+ display.vvvv("Getting network OS from inventory")
+ network_os = self._play_context.network_os
+ elif (
+ "network_os" in task_vars.get("ansible_facts", {})
+ and task_vars["ansible_facts"]["network_os"]
+ ):
+ display.vvvv("Getting network OS from fact")
+ network_os = task_vars["ansible_facts"]["network_os"]
+ else:
+ raise AnsibleError(
+ "ansible_network_os must be specified on this host"
+ )
+
+ return network_os
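+
+ # Resolution precedence illustrated (assumed values): a task argument
+ # network_os="ios" wins over an inventory-level ansible_network_os, which in
+ # turn wins over ansible_facts["network_os"]; if none of the three is set,
+ # the AnsibleError above is raised.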
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py
new file mode 100644
index 00000000..5d05d338
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py
@@ -0,0 +1,209 @@
+#
+# (c) 2018 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import os
+import time
+import re
+
+from ansible.errors import AnsibleError
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.module_utils.six.moves.urllib.parse import urlsplit
+from ansible.plugins.action.normal import ActionModule as _ActionModule
+from ansible.utils.display import Display
+
+display = Display()
+
+PRIVATE_KEYS_RE = re.compile("__.+__")
+
+
+class ActionModule(_ActionModule):
+ def run(self, task_vars=None):
+ config_module = hasattr(self, "_config_module") and self._config_module
+ if config_module and self._task.args.get("src"):
+ try:
+ self._handle_src_option()
+ except AnsibleError as e:
+ return {"failed": True, "msg": e.message, "changed": False}
+
+ result = super(ActionModule, self).run(task_vars=task_vars)
+
+ if (
+ config_module
+ and self._task.args.get("backup")
+ and not result.get("failed")
+ ):
+ self._handle_backup_option(result, task_vars)
+
+ return result
+
+ def _handle_backup_option(self, result, task_vars):
+
+ filename = None
+ backup_path = None
+ try:
+ content = result["__backup__"]
+ except KeyError:
+ raise AnsibleError("Failed while reading configuration backup")
+
+ backup_options = self._task.args.get("backup_options")
+ if backup_options:
+ filename = backup_options.get("filename")
+ backup_path = backup_options.get("dir_path")
+
+ if not backup_path:
+ cwd = self._get_working_path()
+ backup_path = os.path.join(cwd, "backup")
+ if not filename:
+ tstamp = time.strftime(
+ "%Y-%m-%d@%H:%M:%S", time.localtime(time.time())
+ )
+ filename = "%s_config.%s" % (
+ task_vars["inventory_hostname"],
+ tstamp,
+ )
+
+ dest = os.path.join(backup_path, filename)
+ backup_path = os.path.expanduser(
+ os.path.expandvars(
+ to_bytes(backup_path, errors="surrogate_or_strict")
+ )
+ )
+
+ if not os.path.exists(backup_path):
+ os.makedirs(backup_path)
+
+ new_task = self._task.copy()
+ for item in self._task.args:
+ if not item.startswith("_"):
+ new_task.args.pop(item, None)
+
+ new_task.args.update(dict(content=content, dest=dest))
+ copy_action = self._shared_loader_obj.action_loader.get(
+ "copy",
+ task=new_task,
+ connection=self._connection,
+ play_context=self._play_context,
+ loader=self._loader,
+ templar=self._templar,
+ shared_loader_obj=self._shared_loader_obj,
+ )
+ copy_result = copy_action.run(task_vars=task_vars)
+ if copy_result.get("failed"):
+ result["failed"] = copy_result["failed"]
+ result["msg"] = copy_result.get("msg")
+ return
+
+ result["backup_path"] = dest
+ if copy_result.get("changed", False):
+ result["changed"] = copy_result["changed"]
+
+ if backup_options and backup_options.get("filename"):
+ result["date"] = time.strftime(
+ "%Y-%m-%d",
+ time.gmtime(os.stat(result["backup_path"]).st_ctime),
+ )
+ result["time"] = time.strftime(
+ "%H:%M:%S",
+ time.gmtime(os.stat(result["backup_path"]).st_ctime),
+ )
+
+ else:
+ result["date"] = tstamp.split("@")[0]
+ result["time"] = tstamp.split("@")[1]
+ result["shortname"] = result["backup_path"][::-1].split(".", 1)[1][
+ ::-1
+ ]
+ result["filename"] = result["backup_path"].split("/")[-1]
+
+ # strip out any keys that have two leading and two trailing
+ # underscore characters
+ for key in list(result.keys()):
+ if PRIVATE_KEYS_RE.match(key):
+ del result[key]
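+
+ # Quick illustration of the key-stripping regex above (behavior as defined
+ # by PRIVATE_KEYS_RE at module level):
+ #
+ #   import re
+ #   assert re.compile("__.+__").match("__backup__") is not None
+ #   assert re.compile("__.+__").match("backup_path") is None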
+
+ def _get_working_path(self):
+ cwd = self._loader.get_basedir()
+ if self._task._role is not None:
+ cwd = self._task._role._role_path
+ return cwd
+
+ def _handle_src_option(self, convert_data=True):
+ src = self._task.args.get("src")
+ working_path = self._get_working_path()
+
+ if os.path.isabs(src) or urlsplit("src").scheme:
+ source = src
+ else:
+ source = self._loader.path_dwim_relative(
+ working_path, "templates", src
+ )
+ if not source:
+ source = self._loader.path_dwim_relative(working_path, src)
+
+ if not os.path.exists(source):
+ raise AnsibleError("path specified in src not found")
+
+ try:
+ with open(source, "r") as f:
+ template_data = to_text(f.read())
+ except IOError as e:
+ raise AnsibleError(
+ "unable to load src file {0}, I/O error({1}): {2}".format(
+ source, e.errno, e.strerror
+ )
+ )
+
+ # Create a template search path in the following order:
+ # [working_path, self_role_path, dependent_role_paths, dirname(source)]
+ searchpath = [working_path]
+ if self._task._role is not None:
+ searchpath.append(self._task._role._role_path)
+ if hasattr(self._task, "_block:"):
+ dep_chain = self._task._block.get_dep_chain()
+ if dep_chain is not None:
+ for role in dep_chain:
+ searchpath.append(role._role_path)
+ searchpath.append(os.path.dirname(source))
+ with self._templar.set_temporary_context(searchpath=searchpath):
+ self._task.args["src"] = self._templar.template(
+ template_data, convert_data=convert_data
+ )
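+
+ # Sketch of the resulting template search path (hypothetical locations):
+ # for a task in role "edge" with src resolved to
+ # /play/roles/edge/templates/acl.j2, searchpath would be roughly
+ # [/play, /play/roles/edge, <dependent role paths>, /play/roles/edge/templates].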
+
+ def _get_network_os(self, task_vars):
+ if "network_os" in self._task.args and self._task.args["network_os"]:
+ display.vvvv("Getting network OS from task argument")
+ network_os = self._task.args["network_os"]
+ elif self._play_context.network_os:
+ display.vvvv("Getting network OS from inventory")
+ network_os = self._play_context.network_os
+ elif (
+ "network_os" in task_vars.get("ansible_facts", {})
+ and task_vars["ansible_facts"]["network_os"]
+ ):
+ display.vvvv("Getting network OS from fact")
+ network_os = task_vars["ansible_facts"]["network_os"]
+ else:
+ raise AnsibleError(
+ "ansible_network_os must be specified on this host"
+ )
+
+ return network_os
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py
new file mode 100644
index 00000000..33938fd1
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright: (c) 2018, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """become: enable
+short_description: Switch to elevated permissions on a network device
+description:
+- This become plugin allows elevated permissions on a remote network device.
+author: ansible (@core)
+options:
+ become_pass:
+ description: password
+ ini:
+ - section: enable_become_plugin
+ key: password
+ vars:
+ - name: ansible_become_password
+ - name: ansible_become_pass
+ - name: ansible_enable_pass
+ env:
+ - name: ANSIBLE_BECOME_PASS
+ - name: ANSIBLE_ENABLE_PASS
+notes:
+- enable is really implemented in the network connection handler and as such can only
+ be used with network connections.
+- This plugin ignores the 'become_exe' and 'become_user' settings as it uses an API
+ and not an executable.
+"""
+
+from ansible.plugins.become import BecomeBase
+
+
+class BecomeModule(BecomeBase):
+
+ name = "ansible.netcommon.enable"
+
+ def build_become_command(self, cmd, shell):
+ # enable is implemented inside the network connection plugins
+ return cmd
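+
+ # A hedged usage sketch: with a network_cli connection this plugin is
+ # enabled purely through inventory/vars configuration, e.g. (hypothetical
+ # host vars)
+ #
+ #   ansible_connection=ansible.netcommon.network_cli
+ #   ansible_become=yes
+ #   ansible_become_method=ansible.netcommon.enable
+ #
+ # build_become_command() stays a no-op because the terminal plugin performs
+ # the actual privilege escalation (on_become) over the live CLI session.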
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py
new file mode 100644
index 00000000..b063ef0d
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py
@@ -0,0 +1,324 @@
+# (c) 2018 Red Hat Inc.
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """author: Ansible Networking Team
+connection: httpapi
+short_description: Use httpapi to run command on network appliances
+description:
+- This connection plugin provides a connection to remote devices over an HTTP(S)-based
+ API.
+options:
+ host:
+ description:
+ - Specifies the remote device FQDN or IP address to establish the HTTP(S) connection
+ to.
+ default: inventory_hostname
+ vars:
+ - name: ansible_host
+ port:
+ type: int
+ description:
+ - Specifies the port on the remote device that listens for connections when establishing
+ the HTTP(S) connection.
+ - When unspecified, will pick 80 or 443 based on the value of use_ssl.
+ ini:
+ - section: defaults
+ key: remote_port
+ env:
+ - name: ANSIBLE_REMOTE_PORT
+ vars:
+ - name: ansible_httpapi_port
+ network_os:
+ description:
+ - Configures the device platform network operating system. This value is used
+ to load the correct httpapi plugin to communicate with the remote device
+ vars:
+ - name: ansible_network_os
+ remote_user:
+ description:
+ - The username used to authenticate to the remote device when the API connection
+ is first established. If the remote_user is not specified, the connection will
+ use the username of the logged in user.
+ - Can be configured from the CLI via the C(--user) or C(-u) options.
+ ini:
+ - section: defaults
+ key: remote_user
+ env:
+ - name: ANSIBLE_REMOTE_USER
+ vars:
+ - name: ansible_user
+ password:
+ description:
+ - Configures the user password used to authenticate to the remote device when
+ needed for the device API.
+ vars:
+ - name: ansible_password
+ - name: ansible_httpapi_pass
+ - name: ansible_httpapi_password
+ use_ssl:
+ type: boolean
+ description:
+ - Whether to connect using SSL (HTTPS) or not (HTTP).
+ default: false
+ vars:
+ - name: ansible_httpapi_use_ssl
+ validate_certs:
+ type: boolean
+ description:
+ - Whether to validate SSL certificates
+ default: true
+ vars:
+ - name: ansible_httpapi_validate_certs
+ use_proxy:
+ type: boolean
+ description:
+ - Whether to use https_proxy for requests.
+ default: true
+ vars:
+ - name: ansible_httpapi_use_proxy
+ become:
+ type: boolean
+ description:
+ - The become option will instruct the CLI session to attempt privilege escalation
+ on platforms that support it. Normally this means transitioning from user mode
+ to C(enable) mode in the CLI session. If become is set to True and the remote
+ device does not support privilege escalation or the privilege has already been
+ elevated, then this option is silently ignored.
+ - Can be configured from the CLI via the C(--become) or C(-b) options.
+ default: false
+ ini:
+ - section: privilege_escalation
+ key: become
+ env:
+ - name: ANSIBLE_BECOME
+ vars:
+ - name: ansible_become
+ become_method:
+ description:
+ - This option allows the become method to be specified for handling privilege
+ escalation. Typically the become_method value is set to C(enable) but could
+ be defined as other values.
+ default: sudo
+ ini:
+ - section: privilege_escalation
+ key: become_method
+ env:
+ - name: ANSIBLE_BECOME_METHOD
+ vars:
+ - name: ansible_become_method
+ persistent_connect_timeout:
+ type: int
+ description:
+ - Configures, in seconds, the amount of time to wait when trying to initially
+ establish a persistent connection. If this value expires before the connection
+ to the remote device is completed, the connection will fail.
+ default: 30
+ ini:
+ - section: persistent_connection
+ key: connect_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT
+ vars:
+ - name: ansible_connect_timeout
+ persistent_command_timeout:
+ type: int
+ description:
+ - Configures, in seconds, the amount of time to wait for a command to return from
+ the remote device. If this timer is exceeded before the command returns, the
+ connection plugin will raise an exception and close.
+ default: 30
+ ini:
+ - section: persistent_connection
+ key: command_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT
+ vars:
+ - name: ansible_command_timeout
+ persistent_log_messages:
+ type: boolean
+ description:
+ - This flag will enable logging of the command executed and the response received from
+ the target device in the ansible log file. For this option to work, the 'log_path' ansible
+ configuration option must be set to a file path with write access.
+ - Be sure to fully understand the security implications of enabling this option,
+ as it could create a security vulnerability by logging sensitive information
+ in the log file.
+ default: false
+ ini:
+ - section: persistent_connection
+ key: log_messages
+ env:
+ - name: ANSIBLE_PERSISTENT_LOG_MESSAGES
+ vars:
+ - name: ansible_persistent_log_messages
+"""
+
+from io import BytesIO
+
+from ansible.errors import AnsibleConnectionFailure
+from ansible.module_utils._text import to_bytes
+from ansible.module_utils.six import PY3
+from ansible.module_utils.six.moves import cPickle
+from ansible.module_utils.six.moves.urllib.error import HTTPError, URLError
+from ansible.module_utils.urls import open_url
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.loader import httpapi_loader
+from ansible.plugins.connection import NetworkConnectionBase, ensure_connect
+
+
+class Connection(NetworkConnectionBase):
+ """Network API connection"""
+
+ transport = "ansible.netcommon.httpapi"
+ has_pipelining = True
+
+ def __init__(self, play_context, new_stdin, *args, **kwargs):
+ super(Connection, self).__init__(
+ play_context, new_stdin, *args, **kwargs
+ )
+
+ self._url = None
+ self._auth = None
+
+ if self._network_os:
+
+ self.httpapi = httpapi_loader.get(self._network_os, self)
+ if self.httpapi:
+ self._sub_plugin = {
+ "type": "httpapi",
+ "name": self.httpapi._load_name,
+ "obj": self.httpapi,
+ }
+ self.queue_message(
+ "vvvv",
+ "loaded API plugin %s from path %s for network_os %s"
+ % (
+ self.httpapi._load_name,
+ self.httpapi._original_path,
+ self._network_os,
+ ),
+ )
+ else:
+ raise AnsibleConnectionFailure(
+ "unable to load API plugin for network_os %s"
+ % self._network_os
+ )
+
+ else:
+ raise AnsibleConnectionFailure(
+ "Unable to automatically determine host network os. Please "
+ "manually configure ansible_network_os value for this host"
+ )
+ self.queue_message("log", "network_os is set to %s" % self._network_os)
+
+ def update_play_context(self, pc_data):
+ """Updates the play context information for the connection"""
+ pc_data = to_bytes(pc_data)
+ if PY3:
+ pc_data = cPickle.loads(pc_data, encoding="bytes")
+ else:
+ pc_data = cPickle.loads(pc_data)
+ play_context = PlayContext()
+ play_context.deserialize(pc_data)
+
+ self.queue_message("vvvv", "updating play_context for connection")
+ if self._play_context.become ^ play_context.become:
+ self.set_become(play_context)
+ if play_context.become is True:
+ self.queue_message("vvvv", "authorizing connection")
+ else:
+ self.queue_message("vvvv", "deauthorizing connection")
+
+ self._play_context = play_context
+
+ def _connect(self):
+ if not self.connected:
+ protocol = "https" if self.get_option("use_ssl") else "http"
+ host = self.get_option("host")
+ port = self.get_option("port") or (
+ 443 if protocol == "https" else 80
+ )
+ self._url = "%s://%s:%s" % (protocol, host, port)
+
+ self.queue_message(
+ "vvv",
+ "ESTABLISH HTTP(S) CONNECTFOR USER: %s TO %s"
+ % (self._play_context.remote_user, self._url),
+ )
+ self.httpapi.set_become(self._play_context)
+ self._connected = True
+
+ self.httpapi.login(
+ self.get_option("remote_user"), self.get_option("password")
+ )
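+
+ # Sketch of the URL derivation in _connect above (hypothetical values):
+ #   use_ssl=true,  host="appliance1", port unset -> "https://appliance1:443"
+ #   use_ssl=false, host="appliance1", port=8080  -> "http://appliance1:8080"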
+
+ def close(self):
+ """
+ Close the active session to the device
+ """
+ # only close the connection if it's connected.
+ if self._connected:
+ self.queue_message("vvvv", "closing http(s) connection to device")
+ self.logout()
+
+ super(Connection, self).close()
+
+ @ensure_connect
+ def send(self, path, data, **kwargs):
+ """
+ Sends the command to the device over the API
+ """
+ url_kwargs = dict(
+ timeout=self.get_option("persistent_command_timeout"),
+ validate_certs=self.get_option("validate_certs"),
+ use_proxy=self.get_option("use_proxy"),
+ headers={},
+ )
+ url_kwargs.update(kwargs)
+ if self._auth:
+ # Avoid modifying passed-in headers
+ headers = dict(kwargs.get("headers", {}))
+ headers.update(self._auth)
+ url_kwargs["headers"] = headers
+ else:
+ url_kwargs["force_basic_auth"] = True
+ url_kwargs["url_username"] = self.get_option("remote_user")
+ url_kwargs["url_password"] = self.get_option("password")
+
+ try:
+ url = self._url + path
+ self._log_messages(
+ "send url '%s' with data '%s' and kwargs '%s'"
+ % (url, data, url_kwargs)
+ )
+ response = open_url(url, data=data, **url_kwargs)
+ except HTTPError as exc:
+ is_handled = self.handle_httperror(exc)
+ if is_handled is True:
+ return self.send(path, data, **kwargs)
+ elif is_handled is False:
+ raise
+ else:
+ response = is_handled
+ except URLError as exc:
+ raise AnsibleConnectionFailure(
+ "Could not connect to {0}: {1}".format(
+ self._url + path, exc.reason
+ )
+ )
+
+ response_buffer = BytesIO()
+ resp_data = response.read()
+ self._log_messages("received response: '%s'" % resp_data)
+ response_buffer.write(resp_data)
+
+ # Try to assign a new auth token if one is given
+ self._auth = self.update_auth(response, response_buffer) or self._auth
+
+ response_buffer.seek(0)
+
+ return response, response_buffer
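+
+ # A hedged caller-side sketch (path, payload and header are hypothetical;
+ # platform httpapi plugins normally drive this method):
+ #
+ #   response, buf = connection.send(
+ #       "/api/v1/exec",
+ #       data='{"cmd": "show version"}',
+ #       headers={"Content-Type": "application/json"},
+ #   )
+ #
+ # A handled HTTPError (e.g. a refreshed auth token) re-issues the request,
+ # while a URLError surfaces as AnsibleConnectionFailure.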
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py
new file mode 100644
index 00000000..1e2d3caa
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py
@@ -0,0 +1,404 @@
+# (c) 2016 Red Hat Inc.
+# (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """author: Ansible Networking Team
+connection: netconf
+short_description: Provides a persistent connection using the netconf protocol
+description:
+- This connection plugin provides a connection to remote devices over the SSH NETCONF
+ subsystem. This connection plugin is typically used by network devices for sending
+ and receiving RPC calls over NETCONF.
+- Note this connection plugin requires ncclient to be installed on the local Ansible
+ controller.
+requirements:
+- ncclient
+options:
+ host:
+ description:
+ - Specifies the remote device FQDN or IP address to establish the SSH connection
+ to.
+ default: inventory_hostname
+ vars:
+ - name: ansible_host
+ port:
+ type: int
+ description:
+ - Specifies the port on the remote device that listens for connections when establishing
+ the SSH connection.
+ default: 830
+ ini:
+ - section: defaults
+ key: remote_port
+ env:
+ - name: ANSIBLE_REMOTE_PORT
+ vars:
+ - name: ansible_port
+ network_os:
+ description:
+ - Configures the device platform network operating system. This value is used
+ to load a device specific netconf plugin. If this option is not configured
+ (or set to C(auto)), then Ansible will attempt to guess the correct network_os
+ to use. If it can not guess a network_os correctly it will use C(default).
+ vars:
+ - name: ansible_network_os
+ remote_user:
+ description:
+ - The username used to authenticate to the remote device when the SSH connection
+ is first established. If the remote_user is not specified, the connection will
+ use the username of the logged in user.
+ - Can be configured from the CLI via the C(--user) or C(-u) options.
+ ini:
+ - section: defaults
+ key: remote_user
+ env:
+ - name: ANSIBLE_REMOTE_USER
+ vars:
+ - name: ansible_user
+ password:
+ description:
+ - Configures the user password used to authenticate to the remote device when
+ first establishing the SSH connection.
+ vars:
+ - name: ansible_password
+ - name: ansible_ssh_pass
+ - name: ansible_ssh_password
+ - name: ansible_netconf_password
+ private_key_file:
+ description:
+ - The private SSH key or certificate file used to authenticate to the remote device
+ when first establishing the SSH connection.
+ ini:
+ - section: defaults
+ key: private_key_file
+ env:
+ - name: ANSIBLE_PRIVATE_KEY_FILE
+ vars:
+ - name: ansible_private_key_file
+ look_for_keys:
+ default: true
+ description:
+ - Enables looking for ssh keys in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`).
+ env:
+ - name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS
+ ini:
+ - section: paramiko_connection
+ key: look_for_keys
+ type: boolean
+ host_key_checking:
+ description: Set this to "False" if you want to avoid host key checking by the
+ underlying tools Ansible uses to connect to the host
+ type: boolean
+ default: true
+ env:
+ - name: ANSIBLE_HOST_KEY_CHECKING
+ - name: ANSIBLE_SSH_HOST_KEY_CHECKING
+ - name: ANSIBLE_NETCONF_HOST_KEY_CHECKING
+ ini:
+ - section: defaults
+ key: host_key_checking
+ - section: paramiko_connection
+ key: host_key_checking
+ vars:
+ - name: ansible_host_key_checking
+ - name: ansible_ssh_host_key_checking
+ - name: ansible_netconf_host_key_checking
+ persistent_connect_timeout:
+ type: int
+ description:
+ - Configures, in seconds, the amount of time to wait when trying to initially
+ establish a persistent connection. If this value expires before the connection
+ to the remote device is completed, the connection will fail.
+ default: 30
+ ini:
+ - section: persistent_connection
+ key: connect_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT
+ vars:
+ - name: ansible_connect_timeout
+ persistent_command_timeout:
+ type: int
+ description:
+ - Configures, in seconds, the amount of time to wait for a command to return from
+ the remote device. If this timer is exceeded before the command returns, the
+ connection plugin will raise an exception and close.
+ default: 30
+ ini:
+ - section: persistent_connection
+ key: command_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT
+ vars:
+ - name: ansible_command_timeout
+ netconf_ssh_config:
+ description:
+ - This variable is used to enable a bastion/jump host with the netconf connection. If
+ set to True, the bastion/jump host ssh settings should be present in the ~/.ssh/config
+ file; alternatively, it can be set to a custom ssh configuration file path from which to
+ read the bastion/jump host settings.
+ ini:
+ - section: netconf_connection
+ key: ssh_config
+ version_added: '2.7'
+ env:
+ - name: ANSIBLE_NETCONF_SSH_CONFIG
+ vars:
+ - name: ansible_netconf_ssh_config
+ version_added: '2.7'
+ persistent_log_messages:
+ type: boolean
+ description:
+ - This flag will enable logging of the command executed and the response received from
+ target device in the ansible log file. For this option to work, the 'log_path' ansible
+ configuration option must be set to a file path with write access.
+ - Be sure to fully understand the security implications of enabling this option,
+ as it could create a security vulnerability by logging sensitive information
+ in the log file.
+ default: false
+ ini:
+ - section: persistent_connection
+ key: log_messages
+ env:
+ - name: ANSIBLE_PERSISTENT_LOG_MESSAGES
+ vars:
+ - name: ansible_persistent_log_messages
+"""
+
+import os
+import logging
+import json
+
+from ansible.errors import AnsibleConnectionFailure, AnsibleError
+from ansible.module_utils._text import to_bytes, to_native, to_text
+from ansible.module_utils.basic import missing_required_lib
+from ansible.module_utils.parsing.convert_bool import (
+ BOOLEANS_TRUE,
+ BOOLEANS_FALSE,
+)
+from ansible.plugins.loader import netconf_loader
+from ansible.plugins.connection import NetworkConnectionBase, ensure_connect
+
+try:
+ from ncclient import manager
+ from ncclient.operations import RPCError
+ from ncclient.transport.errors import SSHUnknownHostError
+ from ncclient.xml_ import to_ele, to_xml
+
+ HAS_NCCLIENT = True
+ NCCLIENT_IMP_ERR = None
+except (
+ ImportError,
+ AttributeError,
+) as err: # paramiko and gssapi are incompatible and raise AttributeError not ImportError
+ HAS_NCCLIENT = False
+ NCCLIENT_IMP_ERR = err
+
+logging.getLogger("ncclient").setLevel(logging.INFO)
+
+
+class Connection(NetworkConnectionBase):
+ """NetConf connections"""
+
+ transport = "ansible.netcommon.netconf"
+ has_pipelining = False
+
+ def __init__(self, play_context, new_stdin, *args, **kwargs):
+ super(Connection, self).__init__(
+ play_context, new_stdin, *args, **kwargs
+ )
+
+ # If network_os is not specified then set the network os to auto
+ # This will be used to trigger the use of guess_network_os when connecting.
+ self._network_os = self._network_os or "auto"
+
+ self.netconf = netconf_loader.get(self._network_os, self)
+ if self.netconf:
+ self._sub_plugin = {
+ "type": "netconf",
+ "name": self.netconf._load_name,
+ "obj": self.netconf,
+ }
+ self.queue_message(
+ "vvvv",
+ "loaded netconf plugin %s from path %s for network_os %s"
+ % (
+ self.netconf._load_name,
+ self.netconf._original_path,
+ self._network_os,
+ ),
+ )
+ else:
+ self.netconf = netconf_loader.get("default", self)
+ self._sub_plugin = {
+ "type": "netconf",
+ "name": "default",
+ "obj": self.netconf,
+ }
+ self.queue_message(
+ "display",
+ "unable to load netconf plugin for network_os %s, falling back to default plugin"
+ % self._network_os,
+ )
+
+ self.queue_message("log", "network_os is set to %s" % self._network_os)
+ self._manager = None
+ self.key_filename = None
+ self._ssh_config = None
+
+ def exec_command(self, cmd, in_data=None, sudoable=True):
+ """Sends the request to the node and returns the reply
+ The method accepts two forms of request. The first form is a byte
+ string that represents the XML string to be sent over the netconf session.
+ The second form is a json-rpc (2.0) byte string.
+ """
+ if self._manager:
+ # to_ele operates on native strings
+ request = to_ele(to_native(cmd, errors="surrogate_or_strict"))
+
+ if request is None:
+ return "unable to parse request"
+
+ try:
+ reply = self._manager.rpc(request)
+ except RPCError as exc:
+ error = self.internal_error(
+ data=to_text(to_xml(exc.xml), errors="surrogate_or_strict")
+ )
+ return json.dumps(error)
+
+ return reply.data_xml
+ else:
+ return super(Connection, self).exec_command(cmd, in_data, sudoable)
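+
+ # e.g. a caller might pass raw XML such as (hypothetical payload)
+ #   '<get-config><source><running/></source></get-config>'
+ # which to_ele() parses into an element that manager.rpc() sends as a
+ # NETCONF RPC, returning the reply's data_xml on success.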
+
+ @property
+ @ensure_connect
+ def manager(self):
+ return self._manager
+
+ def _connect(self):
+ if not HAS_NCCLIENT:
+ raise AnsibleError(
+ "%s: %s"
+ % (
+ missing_required_lib("ncclient"),
+ to_native(NCCLIENT_IMP_ERR),
+ )
+ )
+
+ self.queue_message("log", "ssh connection done, starting ncclient")
+
+ allow_agent = True
+ if self._play_context.password is not None:
+ allow_agent = False
+ setattr(self._play_context, "allow_agent", allow_agent)
+
+ self.key_filename = (
+ self._play_context.private_key_file
+ or self.get_option("private_key_file")
+ )
+ if self.key_filename:
+ self.key_filename = str(os.path.expanduser(self.key_filename))
+
+ self._ssh_config = self.get_option("netconf_ssh_config")
+ if self._ssh_config in BOOLEANS_TRUE:
+ self._ssh_config = True
+ elif self._ssh_config in BOOLEANS_FALSE:
+ self._ssh_config = None
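+ # Assumed semantics of the option: a true-ish value means "use the default
+ # ~/.ssh/config", a false-ish value disables ssh_config entirely, and any
+ # other string is treated as a path to a custom ssh configuration file.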
+
+ # Try to guess the network_os if the network_os is set to auto
+ if self._network_os == "auto":
+ for cls in netconf_loader.all(class_only=True):
+ network_os = cls.guess_network_os(self)
+ if network_os:
+ self.queue_message(
+ "vvv", "discovered network_os %s" % network_os
+ )
+ self._network_os = network_os
+
+ # If we have tried to detect the network_os but were unable to (i.e. network_os is still 'auto'),
+ # then use 'default' as the network_os
+
+ if self._network_os == "auto":
+ # Network os not discovered. Set it to default
+ self.queue_message(
+ "vvv",
+ "Unable to discover network_os. Falling back to default.",
+ )
+ self._network_os = "default"
+ try:
+ ncclient_device_handler = self.netconf.get_option(
+ "ncclient_device_handler"
+ )
+ except KeyError:
+ ncclient_device_handler = "default"
+ self.queue_message(
+ "vvv",
+ "identified ncclient device handler: %s."
+ % ncclient_device_handler,
+ )
+ device_params = {"name": ncclient_device_handler}
+
+ try:
+ port = self._play_context.port or 830
+ self.queue_message(
+ "vvv",
+ "ESTABLISH NETCONF SSH CONNECTION FOR USER: %s on PORT %s TO %s WITH SSH_CONFIG = %s"
+ % (
+ self._play_context.remote_user,
+ port,
+ self._play_context.remote_addr,
+ self._ssh_config,
+ ),
+ )
+ self._manager = manager.connect(
+ host=self._play_context.remote_addr,
+ port=port,
+ username=self._play_context.remote_user,
+ password=self._play_context.password,
+ key_filename=self.key_filename,
+ hostkey_verify=self.get_option("host_key_checking"),
+ look_for_keys=self.get_option("look_for_keys"),
+ device_params=device_params,
+ allow_agent=self._play_context.allow_agent,
+ timeout=self.get_option("persistent_connect_timeout"),
+ ssh_config=self._ssh_config,
+ )
+
+ self._manager._timeout = self.get_option(
+ "persistent_command_timeout"
+ )
+ except SSHUnknownHostError as exc:
+ raise AnsibleConnectionFailure(to_native(exc))
+ except ImportError:
+ raise AnsibleError(
+ "connection=netconf is not supported on {0}".format(
+ self._network_os
+ )
+ )
+
+ if not self._manager.connected:
+ return 1, b"", b"not connected"
+
+ self.queue_message(
+ "log", "ncclient manager object created successfully"
+ )
+
+ self._connected = True
+
+ super(Connection, self)._connect()
+
+ return (
+ 0,
+ to_bytes(self._manager.session_id, errors="surrogate_or_strict"),
+ b"",
+ )
+
+ def close(self):
+ if self._manager:
+ self._manager.close_session()
+ super(Connection, self).close()
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py
new file mode 100644
index 00000000..8abcf8e8
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py
@@ -0,0 +1,924 @@
+# (c) 2016 Red Hat Inc.
+# (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """author: Ansible Networking Team
+connection: network_cli
+short_description: Use network_cli to run command on network appliances
+description:
+- This connection plugin provides a connection to remote devices over SSH and
+ implements a CLI shell. This connection plugin is typically used with network devices
+ for sending and receiving CLI commands.
+options:
+ host:
+ description:
+ - Specifies the remote device FQDN or IP address to establish the SSH connection
+ to.
+ default: inventory_hostname
+ vars:
+ - name: ansible_host
+ port:
+ type: int
+ description:
+ - Specifies the port on the remote device that listens for connections when establishing
+ the SSH connection.
+ default: 22
+ ini:
+ - section: defaults
+ key: remote_port
+ env:
+ - name: ANSIBLE_REMOTE_PORT
+ vars:
+ - name: ansible_port
+ network_os:
+ description:
+ - Configures the device platform network operating system. This value is used
+ to load the correct terminal and cliconf plugins to communicate with the remote
+ device.
+ vars:
+ - name: ansible_network_os
+ remote_user:
+ description:
+ - The username used to authenticate to the remote device when the SSH connection
+ is first established. If the remote_user is not specified, the connection will
+ use the username of the logged in user.
+ - Can be configured from the CLI via the C(--user) or C(-u) options.
+ ini:
+ - section: defaults
+ key: remote_user
+ env:
+ - name: ANSIBLE_REMOTE_USER
+ vars:
+ - name: ansible_user
+ password:
+ description:
+ - Configures the user password used to authenticate to the remote device when
+ first establishing the SSH connection.
+ vars:
+ - name: ansible_password
+ - name: ansible_ssh_pass
+ - name: ansible_ssh_password
+ private_key_file:
+ description:
+ - The private SSH key or certificate file used to authenticate to the remote device
+ when first establishing the SSH connection.
+ ini:
+ - section: defaults
+ key: private_key_file
+ env:
+ - name: ANSIBLE_PRIVATE_KEY_FILE
+ vars:
+ - name: ansible_private_key_file
+ become:
+ type: boolean
+ description:
+ - The become option will instruct the CLI session to attempt privilege escalation
+ on platforms that support it. Normally this means transitioning from user mode
+ to C(enable) mode in the CLI session. If become is set to True and the remote
+ device does not support privilege escalation or the privilege has already been
+ elevated, then this option is silently ignored.
+ - Can be configured from the CLI via the C(--become) or C(-b) options.
+ default: false
+ ini:
+ - section: privilege_escalation
+ key: become
+ env:
+ - name: ANSIBLE_BECOME
+ vars:
+ - name: ansible_become
+ become_method:
+ description:
+ - This option allows the become method to be specified for handling privilege
+ escalation. Typically the become_method value is set to C(enable) but could
+ be defined as other values.
+ default: sudo
+ ini:
+ - section: privilege_escalation
+ key: become_method
+ env:
+ - name: ANSIBLE_BECOME_METHOD
+ vars:
+ - name: ansible_become_method
+ host_key_auto_add:
+ type: boolean
+ description:
+ - By default, Ansible will prompt the user before adding SSH keys to the known
+ hosts file. Since persistent connections such as network_cli run in background
+ processes, the user will never be prompted. By enabling this option, unknown
+ host keys will automatically be added to the known hosts file.
+ - Be sure to fully understand the security implications of enabling this option
+ on production systems as it could create a security vulnerability.
+ default: false
+ ini:
+ - section: paramiko_connection
+ key: host_key_auto_add
+ env:
+ - name: ANSIBLE_HOST_KEY_AUTO_ADD
+ persistent_connect_timeout:
+ type: int
+ description:
+ - Configures, in seconds, the amount of time to wait when trying to initially
+ establish a persistent connection. If this value expires before the connection
+ to the remote device is completed, the connection will fail.
+ default: 30
+ ini:
+ - section: persistent_connection
+ key: connect_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT
+ vars:
+ - name: ansible_connect_timeout
+ persistent_command_timeout:
+ type: int
+ description:
+ - Configures, in seconds, the amount of time to wait for a command to return from
+ the remote device. If this timer is exceeded before the command returns, the
+ connection plugin will raise an exception and close.
+ default: 30
+ ini:
+ - section: persistent_connection
+ key: command_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT
+ vars:
+ - name: ansible_command_timeout
+ persistent_buffer_read_timeout:
+ type: float
+ description:
+ - Configures, in seconds, the amount of time to wait for the data to be read from the
+ Paramiko channel after the command prompt is matched. This timeout value ensures
+ that the matched command prompt is correct and that there is no more data left to be
+ received from the remote host.
+ default: 0.1
+ ini:
+ - section: persistent_connection
+ key: buffer_read_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_BUFFER_READ_TIMEOUT
+ vars:
+ - name: ansible_buffer_read_timeout
+ persistent_log_messages:
+ type: boolean
+ description:
+ - This flag will enable logging of the command executed and the response received from
+ target device in the ansible log file. For this option to work, the 'log_path' ansible
+ configuration option must be set to a file path with write access.
+ - Be sure to fully understand the security implications of enabling this option,
+ as it could create a security vulnerability by logging sensitive information
+ in the log file.
+ default: false
+ ini:
+ - section: persistent_connection
+ key: log_messages
+ env:
+ - name: ANSIBLE_PERSISTENT_LOG_MESSAGES
+ vars:
+ - name: ansible_persistent_log_messages
+ terminal_stdout_re:
+ type: list
+ elements: dict
+ description:
+ - A single regex pattern or a sequence of patterns along with optional flags to
+ match the command prompt from the received response chunk. This option accepts
+ C(pattern) and C(flags) keys. The value of C(pattern) is a python regex pattern
+ to match the response and the value of C(flags) is the value accepted by I(flags)
+ argument of I(re.compile) python method to control the way regex is matched
+ with the response, for example I('re.I').
+ vars:
+ - name: ansible_terminal_stdout_re
+ terminal_stderr_re:
+ type: list
+ elements: dict
+ description:
+ - This option provides the regex pattern and optional flags to match the error
+ string from the received response chunk. This option accepts C(pattern) and
+ C(flags) keys. The value of C(pattern) is a python regex pattern to match the
+ response and the value of C(flags) is the value accepted by I(flags) argument
+ of I(re.compile) python method to control the way regex is matched with the
+ response, for example I('re.I').
+ vars:
+ - name: ansible_terminal_stderr_re
+ terminal_initial_prompt:
+ type: list
+ description:
+ - A single regex pattern or a sequence of patterns to evaluate the expected prompt
+ at the time of initial login to the remote host.
+ vars:
+ - name: ansible_terminal_initial_prompt
+ terminal_initial_answer:
+ type: list
+ description:
+ - The answer to reply with if the C(terminal_initial_prompt) is matched. The value
+ can be a single answer or a list of answers for multiple terminal_initial_prompt.
+ In case the login menu has multiple prompts, the sequence of the prompts and expected
+ answers should be in the same order, and the value of I(terminal_initial_prompt_checkall)
+ should be set to I(True) if all the values in C(terminal_initial_prompt) are
+ expected to be matched, or I(False) if any one login prompt is to be
+ matched.
+ vars:
+ - name: ansible_terminal_initial_answer
+ terminal_initial_prompt_checkall:
+ type: boolean
+ description:
+ - By default the value is set to I(False): once any one of the prompts mentioned
+ in the C(terminal_initial_prompt) option is matched, it won't check for other prompts.
+ When set to I(True) it will check for all the prompts mentioned in the C(terminal_initial_prompt)
+ option in the given order, and all the prompts should be received from the remote
+ host, otherwise it will result in a timeout.
+ default: false
+ vars:
+ - name: ansible_terminal_initial_prompt_checkall
+ terminal_inital_prompt_newline:
+ type: boolean
+ description:
+ - This boolean flag, when set to I(True), will send a newline in the response
+ if any of the values in I(terminal_initial_prompt) is matched.
+ default: true
+ vars:
+ - name: ansible_terminal_initial_prompt_newline
+ network_cli_retries:
+ description:
+ - Number of attempts to connect to the remote host. The delay time between the retries
+ increases after every attempt by a power of 2 in seconds, until either the maximum
+ attempts are exhausted or any of the C(persistent_command_timeout) or C(persistent_connect_timeout)
+ timers are triggered.
+ default: 3
+ type: integer
+ env:
+ - name: ANSIBLE_NETWORK_CLI_RETRIES
+ ini:
+ - section: persistent_connection
+ key: network_cli_retries
+ vars:
+ - name: ansible_network_cli_retries
+"""
+
+from functools import wraps
+import getpass
+import json
+import logging
+import re
+import os
+import signal
+import socket
+import time
+import traceback
+from io import BytesIO
+
+from ansible.errors import AnsibleConnectionFailure
+from ansible.module_utils.six import PY3
+from ansible.module_utils.six.moves import cPickle
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ to_list,
+)
+from ansible.module_utils._text import to_bytes, to_text
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.connection import NetworkConnectionBase
+from ansible.plugins.loader import (
+ cliconf_loader,
+ terminal_loader,
+ connection_loader,
+)
+
+
+def ensure_connect(func):
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
+ if not self._connected:
+ self._connect()
+ self.update_cli_prompt_context()
+ return func(self, *args, **kwargs)
+
+ return wrapped
+
+
+class AnsibleCmdRespRecv(Exception):
+ pass
+
+
+class Connection(NetworkConnectionBase):
+ """ CLI (shell) SSH connections on Paramiko """
+
+ transport = "ansible.netcommon.network_cli"
+ has_pipelining = True
+
+ def __init__(self, play_context, new_stdin, *args, **kwargs):
+ super(Connection, self).__init__(
+ play_context, new_stdin, *args, **kwargs
+ )
+ self._ssh_shell = None
+
+ self._matched_prompt = None
+ self._matched_cmd_prompt = None
+ self._matched_pattern = None
+ self._last_response = None
+ self._history = list()
+ self._command_response = None
+ self._last_recv_window = None
+
+ self._terminal = None
+ self.cliconf = None
+ self._paramiko_conn = None
+
+ # Managing prompt context
+ self._check_prompt = False
+ self._task_uuid = to_text(kwargs.get("task_uuid", ""))
+
+ if self._play_context.verbosity > 3:
+ logging.getLogger("paramiko").setLevel(logging.DEBUG)
+
+ if self._network_os:
+ self._terminal = terminal_loader.get(self._network_os, self)
+ if not self._terminal:
+ raise AnsibleConnectionFailure(
+ "network os %s is not supported" % self._network_os
+ )
+
+ self.cliconf = cliconf_loader.get(self._network_os, self)
+ if self.cliconf:
+ self._sub_plugin = {
+ "type": "cliconf",
+ "name": self.cliconf._load_name,
+ "obj": self.cliconf,
+ }
+ self.queue_message(
+ "vvvv",
+ "loaded cliconf plugin %s from path %s for network_os %s"
+ % (
+ self.cliconf._load_name,
+ self.cliconf._original_path,
+ self._network_os,
+ ),
+ )
+ else:
+ self.queue_message(
+ "vvvv",
+ "unable to load cliconf for network_os %s"
+ % self._network_os,
+ )
+ else:
+ raise AnsibleConnectionFailure(
+ "Unable to automatically determine host network os. Please "
+ "manually configure ansible_network_os value for this host"
+ )
+ self.queue_message("log", "network_os is set to %s" % self._network_os)
+
+ @property
+ def paramiko_conn(self):
+ if self._paramiko_conn is None:
+ self._paramiko_conn = connection_loader.get(
+ "paramiko", self._play_context, "/dev/null"
+ )
+ self._paramiko_conn.set_options(
+ direct={
+ "look_for_keys": not bool(
+ self._play_context.password
+ and not self._play_context.private_key_file
+ )
+ }
+ )
+ return self._paramiko_conn
+
+ def _get_log_channel(self):
+ name = "p=%s u=%s | " % (os.getpid(), getpass.getuser())
+ name += "paramiko [%s]" % self._play_context.remote_addr
+ return name
+
+ @ensure_connect
+ def get_prompt(self):
+ """Returns the current prompt from the device"""
+ return self._matched_prompt
+
+ def exec_command(self, cmd, in_data=None, sudoable=True):
+ # this try..except block is just to handle the transition to supporting
+ # network_cli as a toplevel connection. Once connection=local is gone,
+ # this block can be removed as well and all calls passed directly to
+ # the local connection
+ if self._ssh_shell:
+ try:
+ cmd = json.loads(to_text(cmd, errors="surrogate_or_strict"))
+ kwargs = {
+ "command": to_bytes(
+ cmd["command"], errors="surrogate_or_strict"
+ )
+ }
+ for key in (
+ "prompt",
+ "answer",
+ "sendonly",
+ "newline",
+ "prompt_retry_check",
+ ):
+ if cmd.get(key) is True or cmd.get(key) is False:
+ kwargs[key] = cmd[key]
+ elif cmd.get(key) is not None:
+ kwargs[key] = to_bytes(
+ cmd[key], errors="surrogate_or_strict"
+ )
+ return self.send(**kwargs)
+ except ValueError:
+ cmd = to_bytes(cmd, errors="surrogate_or_strict")
+ return self.send(command=cmd)
+
+ else:
+ return super(Connection, self).exec_command(cmd, in_data, sudoable)
+
+ def update_play_context(self, pc_data):
+ """Updates the play context information for the connection"""
+ pc_data = to_bytes(pc_data)
+ if PY3:
+ pc_data = cPickle.loads(pc_data, encoding="bytes")
+ else:
+ pc_data = cPickle.loads(pc_data)
+ play_context = PlayContext()
+ play_context.deserialize(pc_data)
+
+ self.queue_message("vvvv", "updating play_context for connection")
+ if self._play_context.become ^ play_context.become:
+ if play_context.become is True:
+ auth_pass = play_context.become_pass
+ self._terminal.on_become(passwd=auth_pass)
+ self.queue_message("vvvv", "authorizing connection")
+ else:
+ self._terminal.on_unbecome()
+ self.queue_message("vvvv", "deauthorizing connection")
+
+ self._play_context = play_context
+
+ if hasattr(self, "reset_history"):
+ self.reset_history()
+ if hasattr(self, "disable_response_logging"):
+ self.disable_response_logging()
+
+ def set_check_prompt(self, task_uuid):
+ self._check_prompt = task_uuid
+
+ def update_cli_prompt_context(self):
+ # set cli prompt context at the start of new task run only
+ if self._check_prompt and self._task_uuid != self._check_prompt:
+ self._task_uuid, self._check_prompt = self._check_prompt, False
+ self.set_cli_prompt_context()
+
+ def _connect(self):
+ """
+ Connects to the remote device and starts the terminal
+ """
+ if not self.connected:
+ self.paramiko_conn._set_log_channel(self._get_log_channel())
+ self.paramiko_conn.force_persistence = self.force_persistence
+
+ command_timeout = self.get_option("persistent_command_timeout")
+ max_pause = min(
+ [
+ self.get_option("persistent_connect_timeout"),
+ command_timeout,
+ ]
+ )
+ retries = self.get_option("network_cli_retries")
+ total_pause = 0
+
+ for attempt in range(retries + 1):
+ try:
+ ssh = self.paramiko_conn._connect()
+ break
+ except Exception as e:
+ pause = 2 ** (attempt + 1)
+ if attempt == retries or total_pause >= max_pause:
+ raise AnsibleConnectionFailure(
+ to_text(e, errors="surrogate_or_strict")
+ )
+ else:
+ msg = (
+ u"network_cli_retry: attempt: %d, caught exception(%s), "
+ u"pausing for %d seconds"
+ % (
+ attempt + 1,
+ to_text(e, errors="surrogate_or_strict"),
+ pause,
+ )
+ )
+
+ self.queue_message("vv", msg)
+ time.sleep(pause)
+ total_pause += pause
+ continue
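+
+ # Retry cadence of the loop above, assuming the default
+ # network_cli_retries=3: pauses of 2, 4 and 8 seconds
+ # (pause = 2 ** (attempt + 1)), capped overall by max_pause.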
+
+ self.queue_message("vvvv", "ssh connection done, setting terminal")
+ self._connected = True
+
+ self._ssh_shell = ssh.ssh.invoke_shell()
+ self._ssh_shell.settimeout(command_timeout)
+
+ self.queue_message(
+ "vvvv",
+ "loaded terminal plugin for network_os %s" % self._network_os,
+ )
+
+ terminal_initial_prompt = (
+ self.get_option("terminal_initial_prompt")
+ or self._terminal.terminal_initial_prompt
+ )
+ terminal_initial_answer = (
+ self.get_option("terminal_initial_answer")
+ or self._terminal.terminal_initial_answer
+ )
+ newline = (
+ self.get_option("terminal_inital_prompt_newline")
+ or self._terminal.terminal_inital_prompt_newline
+ )
+ check_all = (
+ self.get_option("terminal_initial_prompt_checkall") or False
+ )
+
+ self.receive(
+ prompts=terminal_initial_prompt,
+ answer=terminal_initial_answer,
+ newline=newline,
+ check_all=check_all,
+ )
+
+ if self._play_context.become:
+ self.queue_message("vvvv", "firing event: on_become")
+ auth_pass = self._play_context.become_pass
+ self._terminal.on_become(passwd=auth_pass)
+
+ self.queue_message("vvvv", "firing event: on_open_shell()")
+ self._terminal.on_open_shell()
+
+ self.queue_message(
+ "vvvv", "ssh connection has completed successfully"
+ )
+
+ return self
+
+ def close(self):
+ """
+ Close the active connection to the device
+ """
+ # only close the connection if it's connected.
+ if self._connected:
+ self.queue_message("debug", "closing ssh connection to device")
+ if self._ssh_shell:
+ self.queue_message("debug", "firing event: on_close_shell()")
+ self._terminal.on_close_shell()
+ self._ssh_shell.close()
+ self._ssh_shell = None
+ self.queue_message("debug", "cli session is now closed")
+
+ self.paramiko_conn.close()
+ self._paramiko_conn = None
+ self.queue_message(
+ "debug", "ssh connection has been closed successfully"
+ )
+ super(Connection, self).close()
+
+ def receive(
+ self,
+ command=None,
+ prompts=None,
+ answer=None,
+ newline=True,
+ prompt_retry_check=False,
+ check_all=False,
+ ):
+ """
+ Handles receiving of output from command
+ """
+ self._matched_prompt = None
+ self._matched_cmd_prompt = None
+ recv = BytesIO()
+ handled = False
+ command_prompt_matched = False
+ matched_prompt_window = window_count = 0
+
+ # set terminal regex values for command prompt and errors in response
+ self._terminal_stderr_re = self._get_terminal_std_re(
+ "terminal_stderr_re"
+ )
+ self._terminal_stdout_re = self._get_terminal_std_re(
+ "terminal_stdout_re"
+ )
+
+ cache_socket_timeout = self._ssh_shell.gettimeout()
+ command_timeout = self.get_option("persistent_command_timeout")
+ self._validate_timeout_value(
+ command_timeout, "persistent_command_timeout"
+ )
+ if cache_socket_timeout != command_timeout:
+ self._ssh_shell.settimeout(command_timeout)
+
+ buffer_read_timeout = self.get_option("persistent_buffer_read_timeout")
+ self._validate_timeout_value(
+ buffer_read_timeout, "persistent_buffer_read_timeout"
+ )
+
+ self._log_messages("command: %s" % command)
+ while True:
+ if command_prompt_matched:
+ try:
+ signal.signal(
+ signal.SIGALRM, self._handle_buffer_read_timeout
+ )
+ signal.setitimer(signal.ITIMER_REAL, buffer_read_timeout)
+ data = self._ssh_shell.recv(256)
+ signal.alarm(0)
+ self._log_messages(
+ "response-%s: %s" % (window_count + 1, data)
+ )
+ # if data is still received on channel it indicates the prompt string
+ # is wrongly matched in between response chunks, continue to read
+ # remaining response.
+ command_prompt_matched = False
+
+ # restart command_timeout timer
+ signal.signal(signal.SIGALRM, self._handle_command_timeout)
+ signal.alarm(command_timeout)
+
+ except AnsibleCmdRespRecv:
+ # reset socket timeout to global timeout
+ self._ssh_shell.settimeout(cache_socket_timeout)
+ return self._command_response
+ else:
+ data = self._ssh_shell.recv(256)
+ self._log_messages(
+ "response-%s: %s" % (window_count + 1, data)
+ )
+ # when a channel stream is closed, received data will be empty
+ if not data:
+ break
+
+ recv.write(data)
+ offset = recv.tell() - 256 if recv.tell() > 256 else 0
+ recv.seek(offset)
+
+ window = self._strip(recv.read())
+ self._last_recv_window = window
+ window_count += 1
+
+ if prompts and not handled:
+ handled = self._handle_prompt(
+ window, prompts, answer, newline, False, check_all
+ )
+ matched_prompt_window = window_count
+ elif (
+ prompts
+ and handled
+ and prompt_retry_check
+ and matched_prompt_window + 1 == window_count
+ ):
+ # check again even when handled: if the same prompt repeats in the next window
+ # (as in the case of a wrong enable password, etc.) it indicates that the
+ # value of the answer is wrong, so report this as an error.
+ if self._handle_prompt(
+ window,
+ prompts,
+ answer,
+ newline,
+ prompt_retry_check,
+ check_all,
+ ):
+ raise AnsibleConnectionFailure(
+ "For matched prompt '%s', answer is not valid"
+ % self._matched_cmd_prompt
+ )
+
+ if self._find_prompt(window):
+ self._last_response = recv.getvalue()
+ resp = self._strip(self._last_response)
+ self._command_response = self._sanitize(resp, command)
+ if buffer_read_timeout == 0.0:
+ # reset socket timeout to global timeout
+ self._ssh_shell.settimeout(cache_socket_timeout)
+ return self._command_response
+ else:
+ command_prompt_matched = True
+
+ @ensure_connect
+ def send(
+ self,
+ command,
+ prompt=None,
+ answer=None,
+ newline=True,
+ sendonly=False,
+ prompt_retry_check=False,
+ check_all=False,
+ ):
+ """
+ Sends the command to the device in the opened shell
+ """
+ if check_all:
+ prompt_len = len(to_list(prompt))
+ answer_len = len(to_list(answer))
+ if prompt_len != answer_len:
+ raise AnsibleConnectionFailure(
+ "Number of prompts (%s) is not same as that of answers (%s)"
+ % (prompt_len, answer_len)
+ )
+ try:
+ cmd = b"%s\r" % command
+ self._history.append(cmd)
+ self._ssh_shell.sendall(cmd)
+ self._log_messages("send command: %s" % cmd)
+ if sendonly:
+ return
+ response = self.receive(
+ command, prompt, answer, newline, prompt_retry_check, check_all
+ )
+ return to_text(response, errors="surrogate_then_replace")
+ except (socket.timeout, AttributeError):
+ self.queue_message("error", traceback.format_exc())
+ raise AnsibleConnectionFailure(
+ "timeout value %s seconds reached while trying to send command: %s"
+ % (self._ssh_shell.gettimeout(), command.strip())
+ )
+
+ def _handle_buffer_read_timeout(self, signum, frame):
+ self.queue_message(
+ "vvvv",
+ "Response received, triggered 'persistent_buffer_read_timeout' timer of %s seconds"
+ % self.get_option("persistent_buffer_read_timeout"),
+ )
+ raise AnsibleCmdRespRecv()
+
+ def _handle_command_timeout(self, signum, frame):
+ msg = (
+ "command timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and Troubleshooting Guide."
+ % self.get_option("persistent_command_timeout")
+ )
+ self.queue_message("log", msg)
+ raise AnsibleConnectionFailure(msg)
+
+ def _strip(self, data):
+ """
+ Removes ANSI codes from device response
+ """
+ for regex in self._terminal.ansi_re:
+ data = regex.sub(b"", data)
+ return data
+
+ def _handle_prompt(
+ self,
+ resp,
+ prompts,
+ answer,
+ newline,
+ prompt_retry_check=False,
+ check_all=False,
+ ):
+ """
+ Matches the command prompt and responds
+
+ :arg resp: Byte string containing the raw response from the remote
+ :arg prompts: Sequence of byte strings that we consider prompts for input
+        :arg answer: Sequence of byte strings to send back to the remote if we find a prompt.
+                A carriage return is automatically appended to this string.
+        :param prompt_retry_check: Bool value for trying to detect more prompts
+        :param check_all: Bool value to indicate whether all the values in the prompt
+                sequence should be matched or any one of the given prompts.
+        :returns: True if a prompt was found in ``resp``. If check_all is True,
+                  returns True only after all the prompts in the prompts list are matched. False otherwise.
+ """
+ single_prompt = False
+ if not isinstance(prompts, list):
+ prompts = [prompts]
+ single_prompt = True
+ if not isinstance(answer, list):
+ answer = [answer]
+ prompts_regex = [re.compile(to_bytes(r), re.I) for r in prompts]
+ for index, regex in enumerate(prompts_regex):
+ match = regex.search(resp)
+ if match:
+ self._matched_cmd_prompt = match.group()
+ self._log_messages(
+ "matched command prompt: %s" % self._matched_cmd_prompt
+ )
+
+                # if prompt_retry_check is enabled (to check whether the same
+                # prompt is repeated), don't send the answer again.
+ if not prompt_retry_check:
+ prompt_answer = (
+ answer[index] if len(answer) > index else answer[0]
+ )
+ self._ssh_shell.sendall(b"%s" % prompt_answer)
+ if newline:
+ self._ssh_shell.sendall(b"\r")
+ prompt_answer += b"\r"
+ self._log_messages(
+ "matched command prompt answer: %s" % prompt_answer
+ )
+ if check_all and prompts and not single_prompt:
+ prompts.pop(0)
+ answer.pop(0)
+ return False
+ return True
+ return False
+
+ def _sanitize(self, resp, command=None):
+ """
+ Removes elements from the response before returning to the caller
+ """
+ cleaned = []
+ for line in resp.splitlines():
+ if command and line.strip() == command.strip():
+ continue
+
+ for prompt in self._matched_prompt.strip().splitlines():
+ if prompt.strip() in line:
+ break
+ else:
+ cleaned.append(line)
+ return b"\n".join(cleaned).strip()
+
+ def _find_prompt(self, response):
+ """Searches the buffered response for a matching command prompt
+ """
+ errored_response = None
+ is_error_message = False
+
+ for regex in self._terminal_stderr_re:
+ if regex.search(response):
+ is_error_message = True
+
+                # Check whether the error response ends with the command prompt;
+                # if it does not, keep receiving until the prompt is buffered.
+ for regex in self._terminal_stdout_re:
+ match = regex.search(response)
+ if match:
+ errored_response = response
+ self._matched_pattern = regex.pattern
+ self._matched_prompt = match.group()
+ self._log_messages(
+ "matched error regex '%s' from response '%s'"
+ % (self._matched_pattern, errored_response)
+ )
+ break
+
+ if not is_error_message:
+ for regex in self._terminal_stdout_re:
+ match = regex.search(response)
+ if match:
+ self._matched_pattern = regex.pattern
+ self._matched_prompt = match.group()
+ self._log_messages(
+ "matched cli prompt '%s' with regex '%s' from response '%s'"
+ % (
+ self._matched_prompt,
+ self._matched_pattern,
+ response,
+ )
+ )
+ if not errored_response:
+ return True
+
+ if errored_response:
+ raise AnsibleConnectionFailure(errored_response)
+
+ return False
+
+ def _validate_timeout_value(self, timeout, timer_name):
+ if timeout < 0:
+ raise AnsibleConnectionFailure(
+ "'%s' timer value '%s' is invalid, value should be greater than or equal to zero."
+ % (timer_name, timeout)
+ )
+
+ def transport_test(self, connect_timeout):
+ """This method enables wait_for_connection to work.
+
+ As it is used by wait_for_connection, it is called by that module's action plugin,
+ which is on the controller process, which means that nothing done on this instance
+ should impact the actual persistent connection... this check is for informational
+ purposes only and should be properly cleaned up.
+ """
+
+ # Force a fresh connect if for some reason we have connected before.
+ self.close()
+ self._connect()
+ self.close()
+
+ def _get_terminal_std_re(self, option):
+ terminal_std_option = self.get_option(option)
+ terminal_std_re = []
+
+ if terminal_std_option:
+ for item in terminal_std_option:
+ if "pattern" not in item:
+ raise AnsibleConnectionFailure(
+ "'pattern' is a required key for option '%s',"
+ " received option value is %s" % (option, item)
+ )
+ pattern = br"%s" % to_bytes(item["pattern"])
+ flag = item.get("flags", 0)
+ if flag:
+ flag = getattr(re, flag.split(".")[1])
+ terminal_std_re.append(re.compile(pattern, flag))
+ else:
+ # To maintain backward compatibility
+ terminal_std_re = getattr(self._terminal, option)
+
+ return terminal_std_re
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py
new file mode 100644
index 00000000..b29b4872
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py
@@ -0,0 +1,97 @@
+# 2017 Red Hat Inc.
+# (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """author: Ansible Core Team
+connection: persistent
+short_description: Use a persistent unix socket for connection
+description:
+- This is a helper plugin to allow making other connections persistent.
+options:
+ persistent_command_timeout:
+ type: int
+ description:
+ - Configures, in seconds, the amount of time to wait for a command to return from
+ the remote device. If this timer is exceeded before the command returns, the
+ connection plugin will raise an exception and close
+ default: 10
+ ini:
+ - section: persistent_connection
+ key: command_timeout
+ env:
+ - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT
+ vars:
+ - name: ansible_command_timeout
+"""
+from ansible.executor.task_executor import start_connection
+from ansible.plugins.connection import ConnectionBase
+from ansible.module_utils._text import to_text
+from ansible.module_utils.connection import Connection as SocketConnection
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class Connection(ConnectionBase):
+ """ Local based connections """
+
+ transport = "ansible.netcommon.persistent"
+ has_pipelining = False
+
+ def __init__(self, play_context, new_stdin, *args, **kwargs):
+ super(Connection, self).__init__(
+ play_context, new_stdin, *args, **kwargs
+ )
+ self._task_uuid = to_text(kwargs.get("task_uuid", ""))
+
+ def _connect(self):
+ self._connected = True
+ return self
+
+ def exec_command(self, cmd, in_data=None, sudoable=True):
+ display.vvvv(
+ "exec_command(), socket_path=%s" % self.socket_path,
+ host=self._play_context.remote_addr,
+ )
+ connection = SocketConnection(self.socket_path)
+ out = connection.exec_command(cmd, in_data=in_data, sudoable=sudoable)
+ return 0, out, ""
+
+ def put_file(self, in_path, out_path):
+ pass
+
+ def fetch_file(self, in_path, out_path):
+ pass
+
+ def close(self):
+ self._connected = False
+
+ def run(self):
+ """Returns the path of the persistent connection socket.
+
+ Attempts to ensure (within playcontext.timeout seconds) that the
+ socket path exists. If the path exists (or the timeout has expired),
+ returns the socket path.
+ """
+ display.vvvv(
+ "starting connection from persistent connection plugin",
+ host=self._play_context.remote_addr,
+ )
+ variables = {
+ "ansible_command_timeout": self.get_option(
+ "persistent_command_timeout"
+ )
+ }
+ socket_path = start_connection(
+ self._play_context, variables, self._task_uuid
+ )
+ display.vvvv(
+ "local domain socket path is %s" % socket_path,
+ host=self._play_context.remote_addr,
+ )
+ setattr(self, "_socket_path", socket_path)
+ return socket_path
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py
new file mode 100644
index 00000000..8789075a
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+
+class ModuleDocFragment(object):
+
+ # Standard files documentation fragment
+ DOCUMENTATION = r"""options:
+ host:
+ description:
+ - Specifies the DNS host name or address for connecting to the remote device over
+ the specified transport. The value of host is used as the destination address
+ for the transport.
+ type: str
+ required: true
+ port:
+ description:
+ - Specifies the port to use when building the connection to the remote device. The
+ port value will default to port 830.
+ type: int
+ default: 830
+ username:
+ description:
+ - Configures the username to use to authenticate the connection to the remote
+ device. This value is used to authenticate the SSH session. If the value is
+ not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME)
+ will be used instead.
+ type: str
+ password:
+ description:
+ - Specifies the password to use to authenticate the connection to the remote device. This
+ value is used to authenticate the SSH session. If the value is not specified
+ in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) will
+ be used instead.
+ type: str
+ timeout:
+ description:
+ - Specifies the timeout in seconds for communicating with the network device for
+ either connecting or sending commands. If the timeout is exceeded before the
+ operation is completed, the module will error.
+ type: int
+ default: 10
+ ssh_keyfile:
+ description:
+ - Specifies the SSH key to use to authenticate the connection to the remote device. This
+ value is the path to the key used to authenticate the SSH session. If the value
+ is not specified in the task, the value of environment variable C(ANSIBLE_NET_SSH_KEYFILE)
+ will be used instead.
+ type: path
+ hostkey_verify:
+ description:
+  - If set to C(yes), the ssh host key of the device must match an ssh key present
+    on the host; if set to C(no), the ssh host key of the device is not checked.
+ type: bool
+ default: true
+ look_for_keys:
+ description:
+ - Enables looking in the usual locations for the ssh keys (e.g. :file:`~/.ssh/id_*`)
+ type: bool
+ default: true
+notes:
+- For information on using netconf see the :ref:`Platform Options guide using Netconf<netconf_enabled_platform_options>`
+- For more information on using Ansible to manage network devices see the :ref:`Ansible
+ Network Guide <network_guide>`
+"""
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py
new file mode 100644
index 00000000..ad65f6ef
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2019 Ansible, Inc
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+
+class ModuleDocFragment(object):
+
+ # Standard files documentation fragment
+ DOCUMENTATION = r"""options: {}
+notes:
+- This module is supported on C(ansible_network_os) network platforms. See the :ref:`Network
+ Platform Options <platform_options>` for details.
+"""
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py
new file mode 100644
index 00000000..6ae47a73
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py
@@ -0,0 +1,1186 @@
+# (c) 2014, Maciej Delmanowski <drybjed@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+from functools import partial
+import types
+
+try:
+ import netaddr
+except ImportError:
+ # in this case, we'll make the filters return error messages (see bottom)
+ netaddr = None
+else:
+
+ class mac_linux(netaddr.mac_unix):
+ pass
+
+ mac_linux.word_fmt = "%.2x"
+
+from ansible import errors
+
+
+# ---- IP address and network query helpers ----
+def _empty_ipaddr_query(v, vtype):
+ # We don't have any query to process, so just check what type the user
+ # expects, and return the IP address in a correct format
+ if v:
+ if vtype == "address":
+ return str(v.ip)
+ elif vtype == "network":
+ return str(v)
+
+
+def _first_last(v):
+ if v.size == 2:
+ first_usable = int(netaddr.IPAddress(v.first))
+ last_usable = int(netaddr.IPAddress(v.last))
+ return first_usable, last_usable
+ elif v.size > 1:
+ first_usable = int(netaddr.IPAddress(v.first + 1))
+ last_usable = int(netaddr.IPAddress(v.last - 1))
+ return first_usable, last_usable
+
+
+def _6to4_query(v, vtype, value):
+ if v.version == 4:
+
+ if v.size == 1:
+ ipconv = str(v.ip)
+ elif v.size > 1:
+ if v.ip != v.network:
+ ipconv = str(v.ip)
+ else:
+ ipconv = False
+
+ if ipaddr(ipconv, "public"):
+ numbers = list(map(int, ipconv.split(".")))
+
+ try:
+ return "2002:{:02x}{:02x}:{:02x}{:02x}::1/48".format(*numbers)
+ except Exception:
+ return False
+
+ elif v.version == 6:
+ if vtype == "address":
+ if ipaddr(str(v), "2002::/16"):
+ return value
+ elif vtype == "network":
+ if v.ip != v.network:
+ if ipaddr(str(v.ip), "2002::/16"):
+ return value
+ else:
+ return False
+
+
+def _ip_query(v):
+ if v.size == 1:
+ return str(v.ip)
+ if v.size > 1:
+ # /31 networks in netaddr have no broadcast address
+ if v.ip != v.network or not v.broadcast:
+ return str(v.ip)
+
+
+def _gateway_query(v):
+ if v.size > 1:
+ if v.ip != v.network:
+ return str(v.ip) + "/" + str(v.prefixlen)
+
+
+def _address_prefix_query(v):
+ if v.size > 1:
+ if v.ip != v.network:
+ return str(v.ip) + "/" + str(v.prefixlen)
+
+
+def _bool_ipaddr_query(v):
+ if v:
+ return True
+
+
+def _broadcast_query(v):
+ if v.size > 2:
+ return str(v.broadcast)
+
+
+def _cidr_query(v):
+ return str(v)
+
+
+def _cidr_lookup_query(v, iplist, value):
+ try:
+ if v in iplist:
+ return value
+ except Exception:
+ return False
+
+
+def _first_usable_query(v, vtype):
+ if vtype == "address":
+ "Does it make sense to raise an error"
+ raise errors.AnsibleFilterError("Not a network address")
+ elif vtype == "network":
+ if v.size == 2:
+ return str(netaddr.IPAddress(int(v.network)))
+ elif v.size > 1:
+ return str(netaddr.IPAddress(int(v.network) + 1))
+
+
+def _host_query(v):
+ if v.size == 1:
+ return str(v)
+ elif v.size > 1:
+ if v.ip != v.network:
+ return str(v.ip) + "/" + str(v.prefixlen)
+
+
+def _hostmask_query(v):
+ return str(v.hostmask)
+
+
+def _int_query(v, vtype):
+ if vtype == "address":
+ return int(v.ip)
+ elif vtype == "network":
+ return str(int(v.ip)) + "/" + str(int(v.prefixlen))
+
+
+def _ip_prefix_query(v):
+ if v.size == 2:
+ return str(v.ip) + "/" + str(v.prefixlen)
+ elif v.size > 1:
+ if v.ip != v.network:
+ return str(v.ip) + "/" + str(v.prefixlen)
+
+
+def _ip_netmask_query(v):
+ if v.size == 2:
+ return str(v.ip) + " " + str(v.netmask)
+ elif v.size > 1:
+ if v.ip != v.network:
+ return str(v.ip) + " " + str(v.netmask)
+
+
+"""
+def _ip_wildcard_query(v):
+ if v.size == 2:
+ return str(v.ip) + ' ' + str(v.hostmask)
+ elif v.size > 1:
+ if v.ip != v.network:
+ return str(v.ip) + ' ' + str(v.hostmask)
+"""
+
+
+def _ipv4_query(v, value):
+ if v.version == 6:
+ try:
+ return str(v.ipv4())
+ except Exception:
+ return False
+ else:
+ return value
+
+
+def _ipv6_query(v, value):
+ if v.version == 4:
+ return str(v.ipv6())
+ else:
+ return value
+
+
+def _last_usable_query(v, vtype):
+ if vtype == "address":
+ "Does it make sense to raise an error"
+ raise errors.AnsibleFilterError("Not a network address")
+ elif vtype == "network":
+ if v.size > 1:
+ first_usable, last_usable = _first_last(v)
+ return str(netaddr.IPAddress(last_usable))
+
+
+def _link_local_query(v, value):
+ v_ip = netaddr.IPAddress(str(v.ip))
+ if v.version == 4:
+ if ipaddr(str(v_ip), "169.254.0.0/24"):
+ return value
+
+ elif v.version == 6:
+ if ipaddr(str(v_ip), "fe80::/10"):
+ return value
+
+
+def _loopback_query(v, value):
+ v_ip = netaddr.IPAddress(str(v.ip))
+ if v_ip.is_loopback():
+ return value
+
+
+def _multicast_query(v, value):
+ if v.is_multicast():
+ return value
+
+
+def _net_query(v):
+ if v.size > 1:
+ if v.ip == v.network:
+ return str(v.network) + "/" + str(v.prefixlen)
+
+
+def _netmask_query(v):
+ return str(v.netmask)
+
+
+def _network_query(v):
+ """Return the network of a given IP or subnet"""
+ return str(v.network)
+
+
+def _network_id_query(v):
+ """Return the network of a given IP or subnet"""
+ return str(v.network)
+
+
+def _network_netmask_query(v):
+ return str(v.network) + " " + str(v.netmask)
+
+
+def _network_wildcard_query(v):
+ return str(v.network) + " " + str(v.hostmask)
+
+
+def _next_usable_query(v, vtype):
+ if vtype == "address":
+ "Does it make sense to raise an error"
+ raise errors.AnsibleFilterError("Not a network address")
+ elif vtype == "network":
+ if v.size > 1:
+ first_usable, last_usable = _first_last(v)
+ next_ip = int(netaddr.IPAddress(int(v.ip) + 1))
+ if next_ip >= first_usable and next_ip <= last_usable:
+ return str(netaddr.IPAddress(int(v.ip) + 1))
+
+
+def _peer_query(v, vtype):
+ if vtype == "address":
+ raise errors.AnsibleFilterError("Not a network address")
+ elif vtype == "network":
+ if v.size == 2:
+ return str(netaddr.IPAddress(int(v.ip) ^ 1))
+ if v.size == 4:
+ if int(v.ip) % 4 == 0:
+ raise errors.AnsibleFilterError(
+ "Network address of /30 has no peer"
+ )
+ if int(v.ip) % 4 == 3:
+ raise errors.AnsibleFilterError(
+ "Broadcast address of /30 has no peer"
+ )
+ return str(netaddr.IPAddress(int(v.ip) ^ 3))
+ raise errors.AnsibleFilterError("Not a point-to-point network")
+
+
+def _prefix_query(v):
+ return int(v.prefixlen)
+
+
+def _previous_usable_query(v, vtype):
+ if vtype == "address":
+ "Does it make sense to raise an error"
+ raise errors.AnsibleFilterError("Not a network address")
+ elif vtype == "network":
+ if v.size > 1:
+ first_usable, last_usable = _first_last(v)
+ previous_ip = int(netaddr.IPAddress(int(v.ip) - 1))
+ if previous_ip >= first_usable and previous_ip <= last_usable:
+ return str(netaddr.IPAddress(int(v.ip) - 1))
+
+
+def _private_query(v, value):
+ if v.is_private():
+ return value
+
+
+def _public_query(v, value):
+ v_ip = netaddr.IPAddress(str(v.ip))
+ if (
+ v_ip.is_unicast()
+ and not v_ip.is_private()
+ and not v_ip.is_loopback()
+ and not v_ip.is_netmask()
+ and not v_ip.is_hostmask()
+ ):
+ return value
+
+
+def _range_usable_query(v, vtype):
+ if vtype == "address":
+ "Does it make sense to raise an error"
+ raise errors.AnsibleFilterError("Not a network address")
+ elif vtype == "network":
+ if v.size > 1:
+ first_usable, last_usable = _first_last(v)
+ first_usable = str(netaddr.IPAddress(first_usable))
+ last_usable = str(netaddr.IPAddress(last_usable))
+ return "{0}-{1}".format(first_usable, last_usable)
+
+
+def _revdns_query(v):
+ v_ip = netaddr.IPAddress(str(v.ip))
+ return v_ip.reverse_dns
+
+
+def _size_query(v):
+ return v.size
+
+
+def _size_usable_query(v):
+ if v.size == 1:
+ return 0
+ elif v.size == 2:
+ return 2
+ return v.size - 2
+
+
+def _subnet_query(v):
+ return str(v.cidr)
+
+
+def _type_query(v):
+ if v.size == 1:
+ return "address"
+ if v.size > 1:
+ if v.ip != v.network:
+ return "address"
+ else:
+ return "network"
+
+
+def _unicast_query(v, value):
+ if v.is_unicast():
+ return value
+
+
+def _version_query(v):
+ return v.version
+
+
+def _wrap_query(v, vtype, value):
+ if v.version == 6:
+ if vtype == "address":
+ return "[" + str(v.ip) + "]"
+ elif vtype == "network":
+ return "[" + str(v.ip) + "]/" + str(v.prefixlen)
+ else:
+ return value
+
+
+# ---- HWaddr query helpers ----
+def _bare_query(v):
+ v.dialect = netaddr.mac_bare
+ return str(v)
+
+
+def _bool_hwaddr_query(v):
+ if v:
+ return True
+
+
+def _int_hwaddr_query(v):
+ return int(v)
+
+
+def _cisco_query(v):
+ v.dialect = netaddr.mac_cisco
+ return str(v)
+
+
+def _empty_hwaddr_query(v, value):
+ if v:
+ return value
+
+
+def _linux_query(v):
+ v.dialect = mac_linux
+ return str(v)
+
+
+def _postgresql_query(v):
+ v.dialect = netaddr.mac_pgsql
+ return str(v)
+
+
+def _unix_query(v):
+ v.dialect = netaddr.mac_unix
+ return str(v)
+
+
+def _win_query(v):
+ v.dialect = netaddr.mac_eui48
+ return str(v)
+
+
+# ---- IP address and network filters ----
+
+# Returns a minified list of subnets or a single subnet that spans all of
+# the inputs.
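+#
+# Illustrative values (hypothetical, assuming netaddr semantics):
+#   cidr_merge(['192.168.0.0/25', '192.168.0.128/25'])       -> ['192.168.0.0/24']
+#   cidr_merge(['192.168.0.0/24', '192.168.2.0/24'], 'span') -> '192.168.0.0/22'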
+def cidr_merge(value, action="merge"):
+ if not hasattr(value, "__iter__"):
+ raise errors.AnsibleFilterError(
+ "cidr_merge: expected iterable, got " + repr(value)
+ )
+
+ if action == "merge":
+ try:
+ return [str(ip) for ip in netaddr.cidr_merge(value)]
+ except Exception as e:
+ raise errors.AnsibleFilterError(
+ "cidr_merge: error in netaddr:\n%s" % e
+ )
+
+ elif action == "span":
+ # spanning_cidr needs at least two values
+ if len(value) == 0:
+ return None
+ elif len(value) == 1:
+ try:
+ return str(netaddr.IPNetwork(value[0]))
+ except Exception as e:
+ raise errors.AnsibleFilterError(
+ "cidr_merge: error in netaddr:\n%s" % e
+ )
+ else:
+ try:
+ return str(netaddr.spanning_cidr(value))
+ except Exception as e:
+ raise errors.AnsibleFilterError(
+ "cidr_merge: error in netaddr:\n%s" % e
+ )
+
+ else:
+ raise errors.AnsibleFilterError(
+ "cidr_merge: invalid action '%s'" % action
+ )
+
+
+def ipaddr(value, query="", version=False, alias="ipaddr"):
+ """ Check if string is an IP address or network and filter it """
+
+ query_func_extra_args = {
+ "": ("vtype",),
+ "6to4": ("vtype", "value"),
+ "cidr_lookup": ("iplist", "value"),
+ "first_usable": ("vtype",),
+ "int": ("vtype",),
+ "ipv4": ("value",),
+ "ipv6": ("value",),
+ "last_usable": ("vtype",),
+ "link-local": ("value",),
+ "loopback": ("value",),
+ "lo": ("value",),
+ "multicast": ("value",),
+ "next_usable": ("vtype",),
+ "peer": ("vtype",),
+ "previous_usable": ("vtype",),
+ "private": ("value",),
+ "public": ("value",),
+ "unicast": ("value",),
+ "range_usable": ("vtype",),
+ "wrap": ("vtype", "value"),
+ }
+
+ query_func_map = {
+ "": _empty_ipaddr_query,
+ "6to4": _6to4_query,
+ "address": _ip_query,
+ "address/prefix": _address_prefix_query, # deprecate
+ "bool": _bool_ipaddr_query,
+ "broadcast": _broadcast_query,
+ "cidr": _cidr_query,
+ "cidr_lookup": _cidr_lookup_query,
+ "first_usable": _first_usable_query,
+ "gateway": _gateway_query, # deprecate
+ "gw": _gateway_query, # deprecate
+ "host": _host_query,
+ "host/prefix": _address_prefix_query, # deprecate
+ "hostmask": _hostmask_query,
+ "hostnet": _gateway_query, # deprecate
+ "int": _int_query,
+ "ip": _ip_query,
+ "ip/prefix": _ip_prefix_query,
+ "ip_netmask": _ip_netmask_query,
+ # 'ip_wildcard': _ip_wildcard_query, built then could not think of use case
+ "ipv4": _ipv4_query,
+ "ipv6": _ipv6_query,
+ "last_usable": _last_usable_query,
+ "link-local": _link_local_query,
+ "lo": _loopback_query,
+ "loopback": _loopback_query,
+ "multicast": _multicast_query,
+ "net": _net_query,
+ "next_usable": _next_usable_query,
+ "netmask": _netmask_query,
+ "network": _network_query,
+ "network_id": _network_id_query,
+ "network/prefix": _subnet_query,
+ "network_netmask": _network_netmask_query,
+ "network_wildcard": _network_wildcard_query,
+ "peer": _peer_query,
+ "prefix": _prefix_query,
+ "previous_usable": _previous_usable_query,
+ "private": _private_query,
+ "public": _public_query,
+ "range_usable": _range_usable_query,
+ "revdns": _revdns_query,
+ "router": _gateway_query, # deprecate
+ "size": _size_query,
+ "size_usable": _size_usable_query,
+ "subnet": _subnet_query,
+ "type": _type_query,
+ "unicast": _unicast_query,
+ "v4": _ipv4_query,
+ "v6": _ipv6_query,
+ "version": _version_query,
+ "wildcard": _hostmask_query,
+ "wrap": _wrap_query,
+ }
+
+ vtype = None
+
+ if not value:
+ return False
+
+ elif value is True:
+ return False
+
+ # Check if value is a list and parse each element
+ elif isinstance(value, (list, tuple, types.GeneratorType)):
+
+ _ret = []
+ for element in value:
+ if ipaddr(element, str(query), version):
+ _ret.append(ipaddr(element, str(query), version))
+
+ if _ret:
+ return _ret
+ else:
+ return list()
+
+ # Check if value is a number and convert it to an IP address
+ elif str(value).isdigit():
+
+ # We don't know what IP version to assume, so let's check IPv4 first,
+ # then IPv6
+ try:
+ if (not version) or (version and version == 4):
+ v = netaddr.IPNetwork("0.0.0.0/0")
+ v.value = int(value)
+ v.prefixlen = 32
+ elif version and version == 6:
+ v = netaddr.IPNetwork("::/0")
+ v.value = int(value)
+ v.prefixlen = 128
+
+ # IPv4 didn't work the first time, so it definitely has to be IPv6
+ except Exception:
+ try:
+ v = netaddr.IPNetwork("::/0")
+ v.value = int(value)
+ v.prefixlen = 128
+
+ # The value is too big for IPv6. Are you a nanobot?
+ except Exception:
+ return False
+
+ # We got an IP address, let's mark it as such
+ value = str(v)
+ vtype = "address"
+
+ # value has not been recognized, check if it's a valid IP string
+ else:
+ try:
+ v = netaddr.IPNetwork(value)
+
+            # value is a valid IP string; check whether the user specified a
+            # CIDR prefix or just an IP address, as this will indicate the
+            # default output format
+ try:
+ address, prefix = value.split("/")
+ vtype = "network"
+ except Exception:
+ vtype = "address"
+
+ # value hasn't been recognized, maybe it's a numerical CIDR?
+ except Exception:
+ try:
+ address, prefix = value.split("/")
+ address.isdigit()
+ address = int(address)
+ prefix.isdigit()
+ prefix = int(prefix)
+
+ # It's not numerical CIDR, give up
+ except Exception:
+ return False
+
+ # It is something, so let's try and build a CIDR from the parts
+ try:
+ v = netaddr.IPNetwork("0.0.0.0/0")
+ v.value = address
+ v.prefixlen = prefix
+
+ # It's not a valid IPv4 CIDR
+ except Exception:
+ try:
+ v = netaddr.IPNetwork("::/0")
+ v.value = address
+ v.prefixlen = prefix
+
+ # It's not a valid IPv6 CIDR. Give up.
+ except Exception:
+ return False
+
+ # We have a valid CIDR, so let's write it in correct format
+ value = str(v)
+ vtype = "network"
+
+    # We have a query string but it's not in the known query types. Check if
+    # that string is a valid subnet; if so, we can check later whether the
+    # given IP address/network is inside that specific subnet
+ try:
+ # ?? 6to4 and link-local were True here before. Should they still?
+ if (
+ query
+ and (query not in query_func_map or query == "cidr_lookup")
+ and not str(query).isdigit()
+ and ipaddr(query, "network")
+ ):
+ iplist = netaddr.IPSet([netaddr.IPNetwork(query)])
+ query = "cidr_lookup"
+ except Exception:
+ pass
+
+    # This code checks if value matches the IP version the user wants, i.e. if
+    # it's any version ("ipaddr()"), IPv4 ("ipv4()") or IPv6 ("ipv6()")
+ # If version does not match, return False
+ if version and v.version != version:
+ return False
+
+ extras = []
+ for arg in query_func_extra_args.get(query, tuple()):
+ extras.append(locals()[arg])
+ try:
+ return query_func_map[query](v, *extras)
+ except KeyError:
+ try:
+ float(query)
+ if v.size == 1:
+ if vtype == "address":
+ return str(v.ip)
+ elif vtype == "network":
+ return str(v)
+
+ elif v.size > 1:
+ try:
+ return str(v[query]) + "/" + str(v.prefixlen)
+ except Exception:
+ return False
+
+ else:
+ return value
+
+ except Exception:
+ raise errors.AnsibleFilterError(
+ alias + ": unknown filter type: %s" % query
+ )
+
+ return False
+
+
+def ipmath(value, amount):
+ try:
+ if "/" in value:
+ ip = netaddr.IPNetwork(value).ip
+ else:
+ ip = netaddr.IPAddress(value)
+ except (netaddr.AddrFormatError, ValueError):
+ msg = "You must pass a valid IP address; {0} is invalid".format(value)
+ raise errors.AnsibleFilterError(msg)
+
+ if not isinstance(amount, int):
+ msg = (
+ "You must pass an integer for arithmetic; "
+ "{0} is not a valid integer"
+ ).format(amount)
+ raise errors.AnsibleFilterError(msg)
+
+ return str(ip + amount)
+
+
+def ipwrap(value, query=""):
+ try:
+ if isinstance(value, (list, tuple, types.GeneratorType)):
+ _ret = []
+ for element in value:
+ if ipaddr(element, query, version=False, alias="ipwrap"):
+ _ret.append(ipaddr(element, "wrap"))
+ else:
+ _ret.append(element)
+
+ return _ret
+ else:
+ _ret = ipaddr(value, query, version=False, alias="ipwrap")
+ if _ret:
+ return ipaddr(_ret, "wrap")
+ else:
+ return value
+
+ except Exception:
+ return value
+
+
+def ipv4(value, query=""):
+ return ipaddr(value, query, version=4, alias="ipv4")
+
+
+def ipv6(value, query=""):
+ return ipaddr(value, query, version=6, alias="ipv6")
+
+
+# Split given subnet into smaller subnets or find out the biggest subnet of
+# a given IP address with given CIDR prefix
+# Usage:
+#
+# - address or address/prefix | ipsubnet
+# returns CIDR subnet of a given input
+#
+# - address/prefix | ipsubnet(cidr)
+# returns number of possible subnets for given CIDR prefix
+#
+# - address/prefix | ipsubnet(cidr, index)
+# returns new subnet with given CIDR prefix
+#
+# - address | ipsubnet(cidr)
+# returns biggest subnet with given CIDR prefix that address belongs to
+#
+# - address | ipsubnet(cidr, index)
+# returns next indexed subnet which contains given address
+#
+# - address/prefix | ipsubnet(subnet/prefix)
+#   returns the 1-based index of the first subnet within the second subnet
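+#
+# Illustrative values (hypothetical, assuming netaddr semantics):
+#   '192.168.0.1/24' | ipsubnet        -> '192.168.0.0/24'
+#   '192.168.0.1/24' | ipsubnet(25)    -> '2'
+#   '192.168.0.1/24' | ipsubnet(25, 1) -> '192.168.0.128/25'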
+def ipsubnet(value, query="", index="x"):
+ """ Manipulate IPv4/IPv6 subnets """
+
+ try:
+ vtype = ipaddr(value, "type")
+ if vtype == "address":
+ v = ipaddr(value, "cidr")
+ elif vtype == "network":
+ v = ipaddr(value, "subnet")
+
+ value = netaddr.IPNetwork(v)
+ except Exception:
+ return False
+ query_string = str(query)
+ if not query:
+ return str(value)
+
+ elif query_string.isdigit():
+ vsize = ipaddr(v, "size")
+ query = int(query)
+
+ try:
+ float(index)
+ index = int(index)
+
+ if vsize > 1:
+ try:
+ return str(list(value.subnet(query))[index])
+ except Exception:
+ return False
+
+ elif vsize == 1:
+ try:
+ return str(value.supernet(query)[index])
+ except Exception:
+ return False
+
+ except Exception:
+ if vsize > 1:
+ try:
+ return str(len(list(value.subnet(query))))
+ except Exception:
+ return False
+
+ elif vsize == 1:
+ try:
+ return str(value.supernet(query)[0])
+ except Exception:
+ return False
+
+ elif query_string:
+ vtype = ipaddr(query, "type")
+ if vtype == "address":
+ v = ipaddr(query, "cidr")
+ elif vtype == "network":
+ v = ipaddr(query, "subnet")
+ else:
+ msg = "You must pass a valid subnet or IP address; {0} is invalid".format(
+ query_string
+ )
+ raise errors.AnsibleFilterError(msg)
+ query = netaddr.IPNetwork(v)
+ for i, subnet in enumerate(query.subnet(value.prefixlen), 1):
+ if subnet == value:
+ return str(i)
+ msg = "{0} is not in the subnet {1}".format(value.cidr, query.cidr)
+ raise errors.AnsibleFilterError(msg)
+ return False
+
+
+# Returns the nth host within a network described by value.
+# Usage:
+#
+# - address or address/prefix | nthhost(nth)
+# returns the nth host within the given network
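+#
+# Illustrative (hypothetical values): '192.0.2.0/24' | nthhost(5) -> '192.0.2.5'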
+def nthhost(value, query=""):
+ """ Get the nth host within a given network """
+ try:
+ vtype = ipaddr(value, "type")
+ if vtype == "address":
+ v = ipaddr(value, "cidr")
+ elif vtype == "network":
+ v = ipaddr(value, "subnet")
+
+ value = netaddr.IPNetwork(v)
+ except Exception:
+ return False
+
+ if not query:
+ return False
+
+ try:
+ nth = int(query)
+ if value.size > nth:
+ return value[nth]
+
+ except ValueError:
+ return False
+
+ return False
+
+
+# Returns the next nth usable ip within a network described by value.
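+# Illustrative (hypothetical values):
+#   '192.168.122.1/24' | next_nth_usable(2) -> '192.168.122.3'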
+def next_nth_usable(value, offset):
+ try:
+ vtype = ipaddr(value, "type")
+ if vtype == "address":
+ v = ipaddr(value, "cidr")
+ elif vtype == "network":
+ v = ipaddr(value, "subnet")
+
+ v = netaddr.IPNetwork(v)
+ except Exception:
+ return False
+
+ if type(offset) != int:
+ raise errors.AnsibleFilterError("Must pass in an integer")
+ if v.size > 1:
+ first_usable, last_usable = _first_last(v)
+ nth_ip = int(netaddr.IPAddress(int(v.ip) + offset))
+ if nth_ip >= first_usable and nth_ip <= last_usable:
+ return str(netaddr.IPAddress(int(v.ip) + offset))
+
+
+# Returns the previous nth usable ip within a network described by value.
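+# Illustrative (hypothetical values):
+#   '192.168.122.10/24' | previous_nth_usable(2) -> '192.168.122.8'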
+def previous_nth_usable(value, offset):
+ try:
+ vtype = ipaddr(value, "type")
+ if vtype == "address":
+ v = ipaddr(value, "cidr")
+ elif vtype == "network":
+ v = ipaddr(value, "subnet")
+
+ v = netaddr.IPNetwork(v)
+ except Exception:
+ return False
+
+ if type(offset) != int:
+ raise errors.AnsibleFilterError("Must pass in an integer")
+ if v.size > 1:
+ first_usable, last_usable = _first_last(v)
+ nth_ip = int(netaddr.IPAddress(int(v.ip) - offset))
+ if nth_ip >= first_usable and nth_ip <= last_usable:
+ return str(netaddr.IPAddress(int(v.ip) - offset))
+
+
+def _range_checker(ip_check, first, last):
+ """
+ Tests whether an ip address is within the bounds of the first and last address.
+
+ :param ip_check: The ip to test if it is within first and last.
+ :param first: The first IP in the range to test against.
+ :param last: The last IP in the range to test against.
+
+ :return: bool
+ """
+ if ip_check >= first and ip_check <= last:
+ return True
+ else:
+ return False
+
+
+def _address_normalizer(value):
+ """
+ Used to validate an address or network type and return it in a consistent format.
+ This is being used for future use cases not currently available such as an address range.
+
+ :param value: The string representation of an address or network.
+
+ :return: The address or network in the normalized form.
+ """
+ try:
+ vtype = ipaddr(value, "type")
+ if vtype == "address" or vtype == "network":
+ v = ipaddr(value, "subnet")
+ except Exception:
+ return False
+
+ return v
+
+
+def network_in_usable(value, test):
+ """
+    Checks whether 'test' is a usable address or addresses in 'value'
+
+    :param value: The string representation of an address or network to test against.
+ :param test: The string representation of an address or network to validate if it is within the range of 'value'.
+
+ :return: bool
+ """
+ # normalize value and test variables into an ipaddr
+ v = _address_normalizer(value)
+ w = _address_normalizer(test)
+
+    # get first and last addresses as integers to compare value and test; or catches the value when the case is /32
+ v_first = ipaddr(ipaddr(v, "first_usable") or ipaddr(v, "address"), "int")
+ v_last = ipaddr(ipaddr(v, "last_usable") or ipaddr(v, "address"), "int")
+ w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int")
+ w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int")
+
+ if _range_checker(w_first, v_first, v_last) and _range_checker(
+ w_last, v_first, v_last
+ ):
+ return True
+ else:
+ return False
+
+
+def network_in_network(value, test):
+ """
+ Checks whether the 'test' address or addresses are in 'value', including broadcast and network
+
+    :param value: The network address or range to test against.
+ :param test: The address or network to validate if it is within the range of 'value'.
+
+ :return: bool
+ """
+ # normalize value and test variables into an ipaddr
+ v = _address_normalizer(value)
+ w = _address_normalizer(test)
+
+    # get first and last addresses as integers to compare value and test; or catches the value when the case is /32
+ v_first = ipaddr(ipaddr(v, "network") or ipaddr(v, "address"), "int")
+ v_last = ipaddr(ipaddr(v, "broadcast") or ipaddr(v, "address"), "int")
+ w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int")
+ w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int")
+
+ if _range_checker(w_first, v_first, v_last) and _range_checker(
+ w_last, v_first, v_last
+ ):
+ return True
+ else:
+ return False
+
+
+def reduce_on_network(value, network):
+ """
+ Reduces a list of addresses to only the addresses that match a given network.
+
+    :param value: The list of addresses to filter on.
+    :param network: The network to validate against.
+
+ :return: The reduced list of addresses.
+ """
+ # normalize network variable into an ipaddr
+ n = _address_normalizer(network)
+
+    # get first and last addresses as integers to compare value and test; or catches the value when the case is /32
+ n_first = ipaddr(ipaddr(n, "network") or ipaddr(n, "address"), "int")
+ n_last = ipaddr(ipaddr(n, "broadcast") or ipaddr(n, "address"), "int")
+
+ # create an empty list to fill and return
+ r = []
+
+ for address in value:
+ # normalize address variables into an ipaddr
+ a = _address_normalizer(address)
+
+        # get first and last addresses as integers to compare value and test; or catches the value when the case is /32
+ a_first = ipaddr(ipaddr(a, "network") or ipaddr(a, "address"), "int")
+ a_last = ipaddr(ipaddr(a, "broadcast") or ipaddr(a, "address"), "int")
+
+ if _range_checker(a_first, n_first, n_last) and _range_checker(
+ a_last, n_first, n_last
+ ):
+ r.append(address)
+
+ return r
+
+
+# Returns the SLAAC address within a network for a given HW/MAC address.
+# Usage:
+#
+# - prefix | slaac(mac)
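+#
+# Illustrative (hypothetical values, assuming netaddr's modified EUI-64 expansion):
+#   'fdcf:1894:23b5:d38c::/64' | slaac('c2:59:9c:68:1d:43')
+#     -> fdcf:1894:23b5:d38c:c059:9cff:fe68:1d43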
+def slaac(value, query=""):
+ """ Get the SLAAC address within given network """
+ try:
+ vtype = ipaddr(value, "type")
+ if vtype == "address":
+ v = ipaddr(value, "cidr")
+ elif vtype == "network":
+ v = ipaddr(value, "subnet")
+
+ if ipaddr(value, "version") != 6:
+ return False
+
+ value = netaddr.IPNetwork(v)
+ except Exception:
+ return False
+
+ if not query:
+ return False
+
+ try:
+ mac = hwaddr(query, alias="slaac")
+
+ eui = netaddr.EUI(mac)
+ except Exception:
+ return False
+
+ return eui.ipv6(value.network)
+
+
+# ---- HWaddr / MAC address filters ----
+def hwaddr(value, query="", alias="hwaddr"):
+ """ Check if string is a HW/MAC address and filter it """
+
+ query_func_extra_args = {"": ("value",)}
+
+ query_func_map = {
+ "": _empty_hwaddr_query,
+ "bare": _bare_query,
+ "bool": _bool_hwaddr_query,
+ "int": _int_hwaddr_query,
+ "cisco": _cisco_query,
+ "eui48": _win_query,
+ "linux": _linux_query,
+ "pgsql": _postgresql_query,
+ "postgresql": _postgresql_query,
+ "psql": _postgresql_query,
+ "unix": _unix_query,
+ "win": _win_query,
+ }
+
+ try:
+ v = netaddr.EUI(value)
+ except Exception:
+ if query and query != "bool":
+ raise errors.AnsibleFilterError(
+ alias + ": not a hardware address: %s" % value
+ )
+
+ extras = []
+ for arg in query_func_extra_args.get(query, tuple()):
+ extras.append(locals()[arg])
+ try:
+ return query_func_map[query](v, *extras)
+ except KeyError:
+ raise errors.AnsibleFilterError(
+ alias + ": unknown filter type: %s" % query
+ )
+
+ return False
+
+
+def macaddr(value, query=""):
+ return hwaddr(value, query, alias="macaddr")
+
+
+def _need_netaddr(f_name, *args, **kwargs):
+ raise errors.AnsibleFilterError(
+ "The %s filter requires python's netaddr be "
+ "installed on the ansible controller" % f_name
+ )
+
+
+def ip4_hex(arg, delimiter=""):
+ """ Convert an IPv4 address to Hexadecimal notation """
+ numbers = list(map(int, arg.split(".")))
+ return "{0:02x}{sep}{1:02x}{sep}{2:02x}{sep}{3:02x}".format(
+ *numbers, sep=delimiter
+ )
+
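+# Illustrative (hypothetical values): ip4_hex('192.168.0.1') -> 'c0a80001',
+# ip4_hex('192.168.0.1', '.') -> 'c0.a8.00.01'.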
+
+# ---- Ansible filters ----
+class FilterModule(object):
+ """ IP address and network manipulation filters """
+
+ filter_map = {
+ # IP addresses and networks
+ "cidr_merge": cidr_merge,
+ "ipaddr": ipaddr,
+ "ipmath": ipmath,
+ "ipwrap": ipwrap,
+ "ip4_hex": ip4_hex,
+ "ipv4": ipv4,
+ "ipv6": ipv6,
+ "ipsubnet": ipsubnet,
+ "next_nth_usable": next_nth_usable,
+ "network_in_network": network_in_network,
+ "network_in_usable": network_in_usable,
+ "reduce_on_network": reduce_on_network,
+ "nthhost": nthhost,
+ "previous_nth_usable": previous_nth_usable,
+ "slaac": slaac,
+ # MAC / HW addresses
+ "hwaddr": hwaddr,
+ "macaddr": macaddr,
+ }
+
+ def filters(self):
+ if netaddr:
+ return self.filter_map
+ else:
+ # Need to install python's netaddr for these filters to work
+ return dict(
+ (f, partial(_need_netaddr, f)) for f in self.filter_map
+ )
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py
new file mode 100644
index 00000000..f99e6e76
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py
@@ -0,0 +1,531 @@
+#
+# (c) 2017 Red Hat, Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import re
+import os
+import traceback
+import string
+
+from xml.etree.ElementTree import fromstring
+
+from ansible.module_utils._text import to_native, to_text
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ Template,
+)
+from ansible.module_utils.six import iteritems, string_types
+from ansible.module_utils.common._collections_compat import Mapping
+from ansible.errors import AnsibleError, AnsibleFilterError
+from ansible.utils.display import Display
+from ansible.utils.encrypt import passlib_or_crypt, random_password
+
+try:
+ import yaml
+
+ HAS_YAML = True
+except ImportError:
+ HAS_YAML = False
+
+try:
+ import textfsm
+
+ HAS_TEXTFSM = True
+except ImportError:
+ HAS_TEXTFSM = False
+
+display = Display()
+
+
+def re_matchall(regex, value):
+ objects = list()
+ for match in re.findall(regex.pattern, value, re.M):
+ obj = {}
+ if regex.groupindex:
+ for name, index in iteritems(regex.groupindex):
+ if len(regex.groupindex) == 1:
+ obj[name] = match
+ else:
+ obj[name] = match[index - 1]
+ objects.append(obj)
+ return objects
+
+
+def re_search(regex, value):
+ obj = {}
+ match = regex.search(value, re.M)
+ if match:
+ items = list(match.groups())
+ if regex.groupindex:
+ for name, index in iteritems(regex.groupindex):
+ obj[name] = items[index - 1]
+ return obj
+
+
+def parse_cli(output, tmpl):
+ if not isinstance(output, string_types):
+ raise AnsibleError(
+ "parse_cli input should be a string, but was given a input of %s"
+ % (type(output))
+ )
+
+ if not os.path.exists(tmpl):
+ raise AnsibleError("unable to locate parse_cli template: %s" % tmpl)
+
+ try:
+ template = Template()
+ except ImportError as exc:
+ raise AnsibleError(to_native(exc))
+
+ with open(tmpl) as tmpl_fh:
+ tmpl_content = tmpl_fh.read()
+
+ spec = yaml.safe_load(tmpl_content)
+ obj = {}
+
+ for name, attrs in iteritems(spec["keys"]):
+ value = attrs["value"]
+
+ try:
+ variables = spec.get("vars", {})
+ value = template(value, variables)
+ except Exception:
+ pass
+
+ if "start_block" in attrs and "end_block" in attrs:
+ start_block = re.compile(attrs["start_block"])
+ end_block = re.compile(attrs["end_block"])
+
+ blocks = list()
+ lines = None
+ block_started = False
+
+ for line in output.split("\n"):
+ match_start = start_block.match(line)
+ match_end = end_block.match(line)
+
+ if match_start:
+ lines = list()
+ lines.append(line)
+ block_started = True
+
+ elif match_end:
+ if lines:
+ lines.append(line)
+ blocks.append("\n".join(lines))
+ block_started = False
+
+ elif block_started:
+ if lines:
+ lines.append(line)
+
+ regex_items = [re.compile(r) for r in attrs["items"]]
+ objects = list()
+
+ for block in blocks:
+ if isinstance(value, Mapping) and "key" not in value:
+ items = list()
+ for regex in regex_items:
+ match = regex.search(block)
+ if match:
+ item_values = match.groupdict()
+ item_values["match"] = list(match.groups())
+ items.append(item_values)
+ else:
+ items.append(None)
+
+ obj = {}
+ for k, v in iteritems(value):
+ try:
+ obj[k] = template(
+ v, {"item": items}, fail_on_undefined=False
+ )
+ except Exception:
+ obj[k] = None
+ objects.append(obj)
+
+ elif isinstance(value, Mapping):
+ items = list()
+ for regex in regex_items:
+ match = regex.search(block)
+ if match:
+ item_values = match.groupdict()
+ item_values["match"] = list(match.groups())
+ items.append(item_values)
+ else:
+ items.append(None)
+
+ key = template(value["key"], {"item": items})
+ values = dict(
+ [
+ (k, template(v, {"item": items}))
+ for k, v in iteritems(value["values"])
+ ]
+ )
+ objects.append({key: values})
+
+ return objects
+
+ elif "items" in attrs:
+ regexp = re.compile(attrs["items"])
+ when = attrs.get("when")
+ conditional = (
+ "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when
+ )
+
+ if isinstance(value, Mapping) and "key" not in value:
+ values = list()
+
+ for item in re_matchall(regexp, output):
+ entry = {}
+
+ for item_key, item_value in iteritems(value):
+ entry[item_key] = template(item_value, {"item": item})
+
+ if when:
+ if template(conditional, {"item": entry}):
+ values.append(entry)
+ else:
+ values.append(entry)
+
+ obj[name] = values
+
+ elif isinstance(value, Mapping):
+ values = dict()
+
+ for item in re_matchall(regexp, output):
+ entry = {}
+
+ for item_key, item_value in iteritems(value["values"]):
+ entry[item_key] = template(item_value, {"item": item})
+
+ key = template(value["key"], {"item": item})
+
+ if when:
+ if template(
+ conditional, {"item": {"key": key, "value": entry}}
+ ):
+ values[key] = entry
+ else:
+ values[key] = entry
+
+ obj[name] = values
+
+ else:
+ item = re_search(regexp, output)
+ obj[name] = template(value, {"item": item})
+
+ else:
+ obj[name] = value
+
+ return obj
+
+
+def parse_cli_textfsm(value, template):
+ if not HAS_TEXTFSM:
+ raise AnsibleError(
+ "parse_cli_textfsm filter requires TextFSM library to be installed"
+ )
+
+ if not isinstance(value, string_types):
+ raise AnsibleError(
+ "parse_cli_textfsm input should be a string, but was given a input of %s"
+ % (type(value))
+ )
+
+ if not os.path.exists(template):
+ raise AnsibleError(
+ "unable to locate parse_cli_textfsm template: %s" % template
+ )
+
+ try:
+ template = open(template)
+ except IOError as exc:
+ raise AnsibleError(to_native(exc))
+
+ re_table = textfsm.TextFSM(template)
+ fsm_results = re_table.ParseText(value)
+
+ results = list()
+ for item in fsm_results:
+ results.append(dict(zip(re_table.header, item)))
+
+ return results
+
+
+def _extract_param(template, root, attrs, value):
+
+ key = None
+ when = attrs.get("when")
+ conditional = "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when
+ param_to_xpath_map = attrs["items"]
+
+ if isinstance(value, Mapping):
+ key = value.get("key", None)
+ if key:
+ value = value["values"]
+
+ entries = dict() if key else list()
+
+ for element in root.findall(attrs["top"]):
+ entry = dict()
+ item_dict = dict()
+ for param, param_xpath in iteritems(param_to_xpath_map):
+ fields = None
+ try:
+ fields = element.findall(param_xpath)
+ except Exception:
+ display.warning(
+ "Failed to evaluate value of '%s' with XPath '%s'.\nUnexpected error: %s."
+ % (param, param_xpath, traceback.format_exc())
+ )
+
+ tags = param_xpath.split("/")
+
+        # check if the xpath ends with an attribute.
+        # If it does, set the attribute key/value dict as the param value when
+        # the attribute matches; otherwise, for a normal xpath, assign the
+        # matched element's text value.
+ if len(tags) and tags[-1].endswith("]"):
+ if fields:
+ if len(fields) > 1:
+ item_dict[param] = [field.attrib for field in fields]
+ else:
+ item_dict[param] = fields[0].attrib
+ else:
+ item_dict[param] = {}
+ else:
+ if fields:
+ if len(fields) > 1:
+ item_dict[param] = [field.text for field in fields]
+ else:
+ item_dict[param] = fields[0].text
+ else:
+ item_dict[param] = None
+
+ if isinstance(value, Mapping):
+ for item_key, item_value in iteritems(value):
+ entry[item_key] = template(item_value, {"item": item_dict})
+ else:
+ entry = template(value, {"item": item_dict})
+
+ if key:
+ expanded_key = template(key, {"item": item_dict})
+ if when:
+ if template(
+ conditional,
+ {"item": {"key": expanded_key, "value": entry}},
+ ):
+ entries[expanded_key] = entry
+ else:
+ entries[expanded_key] = entry
+ else:
+ if when:
+ if template(conditional, {"item": entry}):
+ entries.append(entry)
+ else:
+ entries.append(entry)
+
+ return entries
+
+
+def parse_xml(output, tmpl):
+ if not os.path.exists(tmpl):
+ raise AnsibleError("unable to locate parse_xml template: %s" % tmpl)
+
+ if not isinstance(output, string_types):
+ raise AnsibleError(
+ "parse_xml works on string input, but given input of : %s"
+ % type(output)
+ )
+
+ root = fromstring(output)
+ try:
+ template = Template()
+ except ImportError as exc:
+ raise AnsibleError(to_native(exc))
+
+ with open(tmpl) as tmpl_fh:
+ tmpl_content = tmpl_fh.read()
+
+ spec = yaml.safe_load(tmpl_content)
+ obj = {}
+
+ for name, attrs in iteritems(spec["keys"]):
+ value = attrs["value"]
+
+ try:
+ variables = spec.get("vars", {})
+ value = template(value, variables)
+ except Exception:
+ pass
+
+ if "items" in attrs:
+ obj[name] = _extract_param(template, root, attrs, value)
+ else:
+ obj[name] = value
+
+ return obj
+
+
+def type5_pw(password, salt=None):
+ if not isinstance(password, string_types):
+ raise AnsibleFilterError(
+ "type5_pw password input should be a string, but was given a input of %s"
+ % (type(password).__name__)
+ )
+
+ salt_chars = u"".join(
+ (to_text(string.ascii_letters), to_text(string.digits), u"./")
+ )
+ if salt is not None and not isinstance(salt, string_types):
+ raise AnsibleFilterError(
+ "type5_pw salt input should be a string, but was given a input of %s"
+ % (type(salt).__name__)
+ )
+ elif not salt:
+ salt = random_password(length=4, chars=salt_chars)
+ elif not set(salt) <= set(salt_chars):
+ raise AnsibleFilterError(
+ "type5_pw salt used inproper characters, must be one of %s"
+ % (salt_chars)
+ )
+
+ encrypted_password = passlib_or_crypt(password, "md5_crypt", salt=salt)
+
+ return encrypted_password
+
+
+def hash_salt(password):
+
+ split_password = password.split("$")
+ if len(split_password) != 4:
+ raise AnsibleFilterError(
+ "Could not parse salt out password correctly from {0}".format(
+ password
+ )
+ )
+ else:
+ return split_password[2]
+
+
+def comp_type5(
+ unencrypted_password, encrypted_password, return_original=False
+):
+
+ salt = hash_salt(encrypted_password)
+ if type5_pw(unencrypted_password, salt) == encrypted_password:
+ if return_original is True:
+ return encrypted_password
+ else:
+ return True
+ return False
+
+
+def vlan_parser(vlan_list, first_line_len=48, other_line_len=44):
+
+ """
+ Input: Unsorted list of vlan integers
+ Output: Sorted string list of integers according to IOS-like vlan list rules
+
+ 1. Vlans are listed in ascending order
+ 2. Runs of 3 or more consecutive vlans are listed with a dash
+ 3. The first line of the list can be first_line_len characters long
+ 4. Subsequent list lines can be other_line_len characters
+ """
+
+ # Sort and remove duplicates
+ sorted_list = sorted(set(vlan_list))
+
+ if sorted_list[0] < 1 or sorted_list[-1] > 4094:
+ raise AnsibleFilterError("Valid VLAN range is 1-4094")
+
+ parse_list = []
+ idx = 0
+ while idx < len(sorted_list):
+ start = idx
+ end = start
+ while end < len(sorted_list) - 1:
+ if sorted_list[end + 1] - sorted_list[end] == 1:
+ end += 1
+ else:
+ break
+
+ if start == end:
+ # Single VLAN
+ parse_list.append(str(sorted_list[idx]))
+ elif start + 1 == end:
+ # Run of 2 VLANs
+ parse_list.append(str(sorted_list[start]))
+ parse_list.append(str(sorted_list[end]))
+ else:
+ # Run of 3 or more VLANs
+ parse_list.append(
+ str(sorted_list[start]) + "-" + str(sorted_list[end])
+ )
+ idx = end + 1
+
+ line_count = 0
+ result = [""]
+ for vlans in parse_list:
+ # First line (" switchport trunk allowed vlan ")
+ if line_count == 0:
+ if len(result[line_count] + vlans) > first_line_len:
+ result.append("")
+ line_count += 1
+ result[line_count] += vlans + ","
+ else:
+ result[line_count] += vlans + ","
+
+ # Subsequent lines (" switchport trunk allowed vlan add ")
+ else:
+ if len(result[line_count] + vlans) > other_line_len:
+ result.append("")
+ line_count += 1
+ result[line_count] += vlans + ","
+ else:
+ result[line_count] += vlans + ","
+
+ # Remove trailing orphan commas
+ for idx in range(0, len(result)):
+ result[idx] = result[idx].rstrip(",")
+
+ # Sometimes text wraps to next line, but there are no remaining VLANs
+ if "" in result:
+ result.remove("")
+
+ return result
+
+
+class FilterModule(object):
+ """Filters for working with output from network devices"""
+
+ filter_map = {
+ "parse_cli": parse_cli,
+ "parse_cli_textfsm": parse_cli_textfsm,
+ "parse_xml": parse_xml,
+ "type5_pw": type5_pw,
+ "hash_salt": hash_salt,
+ "comp_type5": comp_type5,
+ "vlan_parser": vlan_parser,
+ }
+
+ def filters(self):
+ return self.filter_map
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py
new file mode 100644
index 00000000..8afb3e5e
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2018 Cisco and/or its affiliates.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """author: Ansible Networking Team
+httpapi: restconf
+short_description: HttpApi Plugin for devices supporting Restconf API
+description:
+- This HttpApi plugin provides methods to connect to Restconf API endpoints.
+options:
+ root_path:
+ type: str
+ description:
+ - Specifies the location of the Restconf root.
+ default: /restconf
+ vars:
+ - name: ansible_httpapi_restconf_root
+"""
+
+import json
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.connection import ConnectionError
+from ansible.module_utils.six.moves.urllib.error import HTTPError
+from ansible.plugins.httpapi import HttpApiBase
+
+
+CONTENT_TYPE = "application/yang-data+json"
+
+
+class HttpApi(HttpApiBase):
+ def send_request(self, data, **message_kwargs):
+ if data:
+ data = json.dumps(data)
+
+ path = "/".join(
+ [
+ self.get_option("root_path").rstrip("/"),
+ message_kwargs.get("path", "").lstrip("/"),
+ ]
+ )
+
+ headers = {
+ "Content-Type": message_kwargs.get("content_type") or CONTENT_TYPE,
+ "Accept": message_kwargs.get("accept") or CONTENT_TYPE,
+ }
+ response, response_data = self.connection.send(
+ path, data, headers=headers, method=message_kwargs.get("method")
+ )
+
+ return handle_response(response, response_data)
+
+
+def handle_response(response, response_data):
+    # Read the body exactly once; calling read() twice on the same buffer
+    # would return an empty result and lose a non-JSON payload.
+    raw = response_data.read()
+    try:
+        response_data = json.loads(raw)
+    except ValueError:
+        response_data = raw
+
+ if isinstance(response, HTTPError):
+ if response_data:
+ if "errors" in response_data:
+ errors = response_data["errors"]["error"]
+ error_text = "\n".join(
+ (error["error-message"] for error in errors)
+ )
+ else:
+ error_text = response_data
+
+ raise ConnectionError(error_text, code=response.code)
+ raise ConnectionError(to_text(response), code=response.code)
+
+ return response_data
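+
+
+# Illustrative usage (a sketch under assumed module context): a module running
+# over an httpapi connection can issue RESTCONF calls through Connection,
+# which proxies to HttpApi.send_request() above, e.g.
+#     from ansible.module_utils.connection import Connection
+#     conn = Connection(module._socket_path)
+#     data = conn.send_request(
+#         None, path="data/ietf-interfaces:interfaces", method="GET"
+#     )
+# The plugin joins the configured root_path (default "/restconf") with the
+# given path and parses the yang-data+json response body.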
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py
new file mode 100644
index 00000000..dc0a19f7
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py
@@ -0,0 +1,2578 @@
+# -*- coding: utf-8 -*-
+
+# This code is part of Ansible, but is an independent component.
+# This particular file, and this file only, is based on
+# Lib/ipaddress.py of cpython
+# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
+#
+# 1. This LICENSE AGREEMENT is between the Python Software Foundation
+# ("PSF"), and the Individual or Organization ("Licensee") accessing and
+# otherwise using this software ("Python") in source or binary form and
+# its associated documentation.
+#
+# 2. Subject to the terms and conditions of this License Agreement, PSF hereby
+# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
+# analyze, test, perform and/or display publicly, prepare derivative works,
+# distribute, and otherwise use Python alone or in any derivative version,
+# provided, however, that PSF's License Agreement and PSF's notice of copyright,
+# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
+# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved"
+# are retained in Python alone or in any derivative version prepared by Licensee.
+#
+# 3. In the event Licensee prepares a derivative work that is based on
+# or incorporates Python or any part thereof, and wants to make
+# the derivative work available to others as provided herein, then
+# Licensee hereby agrees to include in any such work a brief summary of
+# the changes made to Python.
+#
+# 4. PSF is making Python available to Licensee on an "AS IS"
+# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
+# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
+# INFRINGE ANY THIRD PARTY RIGHTS.
+#
+# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
+# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+#
+# 6. This License Agreement will automatically terminate upon a material
+# breach of its terms and conditions.
+#
+# 7. Nothing in this License Agreement shall be deemed to create any
+# relationship of agency, partnership, or joint venture between PSF and
+# Licensee. This License Agreement does not grant permission to use PSF
+# trademarks or trade name in a trademark sense to endorse or promote
+# products or services of Licensee, or any third party.
+#
+# 8. By copying, installing or otherwise using Python, Licensee
+# agrees to be bound by the terms and conditions of this License
+# Agreement.
+
+# Copyright 2007 Google Inc.
+# Licensed to PSF under a Contributor Agreement.
+
+"""A fast, lightweight IPv4/IPv6 manipulation library in Python.
+
+This library is used to create/poke/manipulate IPv4 and IPv6 addresses
+and networks.
+
+"""
+
+from __future__ import unicode_literals
+
+
+import itertools
+import struct
+
+
+# The following makes it easier for us to script updates of the bundled code and is not part of
+# upstream
+_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"}
+
+__version__ = "1.0.22"
+
+# Compatibility functions
+_compat_int_types = (int,)
+try:
+ _compat_int_types = (int, long)
+except NameError:
+ pass
+try:
+ _compat_str = unicode
+except NameError:
+ _compat_str = str
+ assert bytes != str
+if b"\0"[0] == 0: # Python 3 semantics
+
+ def _compat_bytes_to_byte_vals(byt):
+ return byt
+
+
+else:
+
+ def _compat_bytes_to_byte_vals(byt):
+ return [struct.unpack(b"!B", b)[0] for b in byt]
+
+
+try:
+ _compat_int_from_byte_vals = int.from_bytes
+except AttributeError:
+
+ def _compat_int_from_byte_vals(bytvals, endianess):
+ assert endianess == "big"
+ res = 0
+ for bv in bytvals:
+ assert isinstance(bv, _compat_int_types)
+ res = (res << 8) + bv
+ return res
+
+
+def _compat_to_bytes(intval, length, endianess):
+ assert isinstance(intval, _compat_int_types)
+ assert endianess == "big"
+ if length == 4:
+ if intval < 0 or intval >= 2 ** 32:
+ raise struct.error("integer out of range for 'I' format code")
+ return struct.pack(b"!I", intval)
+ elif length == 16:
+ if intval < 0 or intval >= 2 ** 128:
+ raise struct.error("integer out of range for 'QQ' format code")
+ return struct.pack(b"!QQ", intval >> 64, intval & 0xFFFFFFFFFFFFFFFF)
+ else:
+ raise NotImplementedError()
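+    # Illustrative: _compat_to_bytes(3221225985, 4, "big") returns
+    # b"\xc0\x00\x02\x01", the packed form of 192.0.2.1.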
+
+
+if hasattr(int, "bit_length"):
+    # Not int.bit_length, since that won't work in 2.7 where long exists
+ def _compat_bit_length(i):
+ return i.bit_length()
+
+
+else:
+
+ def _compat_bit_length(i):
+ for res in itertools.count():
+ if i >> res == 0:
+ return res
+
+
+def _compat_range(start, end, step=1):
+ assert step > 0
+ i = start
+ while i < end:
+ yield i
+ i += step
+
+
+class _TotalOrderingMixin(object):
+ __slots__ = ()
+
+ # Helper that derives the other comparison operations from
+ # __lt__ and __eq__
+ # We avoid functools.total_ordering because it doesn't handle
+ # NotImplemented correctly yet (http://bugs.python.org/issue10042)
+ def __eq__(self, other):
+ raise NotImplementedError
+
+ def __ne__(self, other):
+ equal = self.__eq__(other)
+ if equal is NotImplemented:
+ return NotImplemented
+ return not equal
+
+ def __lt__(self, other):
+ raise NotImplementedError
+
+ def __le__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented or not less:
+ return self.__eq__(other)
+ return less
+
+ def __gt__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented:
+ return NotImplemented
+ equal = self.__eq__(other)
+ if equal is NotImplemented:
+ return NotImplemented
+ return not (less or equal)
+
+ def __ge__(self, other):
+ less = self.__lt__(other)
+ if less is NotImplemented:
+ return NotImplemented
+ return not less
+
+
+IPV4LENGTH = 32
+IPV6LENGTH = 128
+
+
+class AddressValueError(ValueError):
+ """A Value Error related to the address."""
+
+
+class NetmaskValueError(ValueError):
+ """A Value Error related to the netmask."""
+
+
+def ip_address(address):
+ """Take an IP string/int and return an object of the correct type.
+
+ Args:
+ address: A string or integer, the IP address. Either IPv4 or
+ IPv6 addresses may be supplied; integers less than 2**32 will
+ be considered to be IPv4 by default.
+
+ Returns:
+ An IPv4Address or IPv6Address object.
+
+ Raises:
+ ValueError: if the *address* passed isn't either a v4 or a v6
+ address
+
+ """
+ try:
+ return IPv4Address(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ try:
+ return IPv6Address(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ if isinstance(address, bytes):
+ raise AddressValueError(
+ "%r does not appear to be an IPv4 or IPv6 address. "
+ "Did you pass in a bytes (str in Python 2) instead of"
+ " a unicode object?" % address
+ )
+
+ raise ValueError(
+ "%r does not appear to be an IPv4 or IPv6 address" % address
+ )
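+    # Illustrative behaviour of this factory:
+    #     ip_address(u"192.0.2.1")   -> IPv4Address('192.0.2.1')
+    #     ip_address(u"2001:db8::1") -> IPv6Address('2001:db8::1')
+    #     ip_address(3221225985)     -> IPv4Address('192.0.2.1')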
+
+
+def ip_network(address, strict=True):
+ """Take an IP string/int and return an object of the correct type.
+
+ Args:
+ address: A string or integer, the IP network. Either IPv4 or
+ IPv6 networks may be supplied; integers less than 2**32 will
+ be considered to be IPv4 by default.
+
+ Returns:
+ An IPv4Network or IPv6Network object.
+
+ Raises:
+ ValueError: if the string passed isn't either a v4 or a v6
+ address. Or if the network has host bits set.
+
+ """
+ try:
+ return IPv4Network(address, strict)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ try:
+ return IPv6Network(address, strict)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ if isinstance(address, bytes):
+ raise AddressValueError(
+ "%r does not appear to be an IPv4 or IPv6 network. "
+ "Did you pass in a bytes (str in Python 2) instead of"
+ " a unicode object?" % address
+ )
+
+ raise ValueError(
+ "%r does not appear to be an IPv4 or IPv6 network" % address
+ )
+
+
+def ip_interface(address):
+ """Take an IP string/int and return an object of the correct type.
+
+ Args:
+ address: A string or integer, the IP address. Either IPv4 or
+ IPv6 addresses may be supplied; integers less than 2**32 will
+ be considered to be IPv4 by default.
+
+ Returns:
+ An IPv4Interface or IPv6Interface object.
+
+ Raises:
+ ValueError: if the string passed isn't either a v4 or a v6
+ address.
+
+ Notes:
+ The IPv?Interface classes describe an Address on a particular
+ Network, so they're basically a combination of both the Address
+ and Network classes.
+
+ """
+ try:
+ return IPv4Interface(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ try:
+ return IPv6Interface(address)
+ except (AddressValueError, NetmaskValueError):
+ pass
+
+ raise ValueError(
+ "%r does not appear to be an IPv4 or IPv6 interface" % address
+ )
+
+
+def v4_int_to_packed(address):
+ """Represent an address as 4 packed bytes in network (big-endian) order.
+
+ Args:
+ address: An integer representation of an IPv4 IP address.
+
+ Returns:
+ The integer address packed as 4 bytes in network (big-endian) order.
+
+ Raises:
+ ValueError: If the integer is negative or too large to be an
+ IPv4 IP address.
+
+ """
+ try:
+ return _compat_to_bytes(address, 4, "big")
+ except (struct.error, OverflowError):
+ raise ValueError("Address negative or too large for IPv4")
+
+
+def v6_int_to_packed(address):
+ """Represent an address as 16 packed bytes in network (big-endian) order.
+
+ Args:
+ address: An integer representation of an IPv6 IP address.
+
+ Returns:
+ The integer address packed as 16 bytes in network (big-endian) order.
+
+ """
+ try:
+ return _compat_to_bytes(address, 16, "big")
+ except (struct.error, OverflowError):
+ raise ValueError("Address negative or too large for IPv6")
+
+
+def _split_optional_netmask(address):
+ """Helper to split the netmask and raise AddressValueError if needed"""
+ addr = _compat_str(address).split("/")
+ if len(addr) > 2:
+ raise AddressValueError("Only one '/' permitted in %r" % address)
+ return addr
+
+
+def _find_address_range(addresses):
+ """Find a sequence of sorted deduplicated IPv#Address.
+
+ Args:
+ addresses: a list of IPv#Address objects.
+
+ Yields:
+ A tuple containing the first and last IP addresses in the sequence.
+
+ """
+ it = iter(addresses)
+ first = last = next(it) # pylint: disable=stop-iteration-return
+ for ip in it:
+ if ip._ip != last._ip + 1:
+ yield first, last
+ first = ip
+ last = ip
+ yield first, last
+
+
+def _count_righthand_zero_bits(number, bits):
+ """Count the number of zero bits on the right hand side.
+
+ Args:
+ number: an integer.
+ bits: maximum number of bits to count.
+
+ Returns:
+ The number of zero bits on the right hand side of the number.
+
+ """
+ if number == 0:
+ return bits
+ return min(bits, _compat_bit_length(~number & (number - 1)))
+
+
+def summarize_address_range(first, last):
+ """Summarize a network range given the first and last IP addresses.
+
+ Example:
+ >>> list(summarize_address_range(IPv4Address('192.0.2.0'),
+ ... IPv4Address('192.0.2.130')))
+ ... #doctest: +NORMALIZE_WHITESPACE
+ [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'),
+ IPv4Network('192.0.2.130/32')]
+
+ Args:
+ first: the first IPv4Address or IPv6Address in the range.
+ last: the last IPv4Address or IPv6Address in the range.
+
+ Returns:
+ An iterator of the summarized IPv(4|6) network objects.
+
+ Raise:
+ TypeError:
+ If the first and last objects are not IP addresses.
+ If the first and last objects are not the same version.
+ ValueError:
+ If the last object is not greater than the first.
+ If the version of the first address is not 4 or 6.
+
+ """
+ if not (
+ isinstance(first, _BaseAddress) and isinstance(last, _BaseAddress)
+ ):
+ raise TypeError("first and last must be IP addresses, not networks")
+ if first.version != last.version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (first, last)
+ )
+ if first > last:
+ raise ValueError("last IP address must be greater than first")
+
+ if first.version == 4:
+ ip = IPv4Network
+ elif first.version == 6:
+ ip = IPv6Network
+ else:
+ raise ValueError("unknown IP version")
+
+ ip_bits = first._max_prefixlen
+ first_int = first._ip
+ last_int = last._ip
+ while first_int <= last_int:
+ nbits = min(
+ _count_righthand_zero_bits(first_int, ip_bits),
+ _compat_bit_length(last_int - first_int + 1) - 1,
+ )
+ net = ip((first_int, ip_bits - nbits))
+ yield net
+ first_int += 1 << nbits
+ if first_int - 1 == ip._ALL_ONES:
+ break
+
+
+def _collapse_addresses_internal(addresses):
+ """Loops through the addresses, collapsing concurrent netblocks.
+
+ Example:
+
+ ip1 = IPv4Network('192.0.2.0/26')
+ ip2 = IPv4Network('192.0.2.64/26')
+ ip3 = IPv4Network('192.0.2.128/26')
+ ip4 = IPv4Network('192.0.2.192/26')
+
+ _collapse_addresses_internal([ip1, ip2, ip3, ip4]) ->
+ [IPv4Network('192.0.2.0/24')]
+
+    This shouldn't be called directly; it is called via
+    collapse_addresses().
+
+ Args:
+ addresses: A list of IPv4Network's or IPv6Network's
+
+ Returns:
+ A list of IPv4Network's or IPv6Network's depending on what we were
+ passed.
+
+ """
+ # First merge
+ to_merge = list(addresses)
+ subnets = {}
+ while to_merge:
+ net = to_merge.pop()
+ supernet = net.supernet()
+ existing = subnets.get(supernet)
+ if existing is None:
+ subnets[supernet] = net
+ elif existing != net:
+ # Merge consecutive subnets
+ del subnets[supernet]
+ to_merge.append(supernet)
+ # Then iterate over resulting networks, skipping subsumed subnets
+ last = None
+ for net in sorted(subnets.values()):
+ if last is not None:
+ # Since they are sorted,
+ # last.network_address <= net.network_address is a given.
+ if last.broadcast_address >= net.broadcast_address:
+ continue
+ yield net
+ last = net
+
+
+def collapse_addresses(addresses):
+ """Collapse a list of IP objects.
+
+ Example:
+ collapse_addresses([IPv4Network('192.0.2.0/25'),
+ IPv4Network('192.0.2.128/25')]) ->
+ [IPv4Network('192.0.2.0/24')]
+
+ Args:
+ addresses: An iterator of IPv4Network or IPv6Network objects.
+
+ Returns:
+ An iterator of the collapsed IPv(4|6)Network objects.
+
+ Raises:
+ TypeError: If passed a list of mixed version objects.
+
+ """
+ addrs = []
+ ips = []
+ nets = []
+
+ # split IP addresses and networks
+ for ip in addresses:
+ if isinstance(ip, _BaseAddress):
+ if ips and ips[-1]._version != ip._version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (ip, ips[-1])
+ )
+ ips.append(ip)
+ elif ip._prefixlen == ip._max_prefixlen:
+ if ips and ips[-1]._version != ip._version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (ip, ips[-1])
+ )
+ try:
+ ips.append(ip.ip)
+ except AttributeError:
+ ips.append(ip.network_address)
+ else:
+ if nets and nets[-1]._version != ip._version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (ip, nets[-1])
+ )
+ nets.append(ip)
+
+ # sort and dedup
+ ips = sorted(set(ips))
+
+ # find consecutive address ranges in the sorted sequence and summarize them
+ if ips:
+ for first, last in _find_address_range(ips):
+ addrs.extend(summarize_address_range(first, last))
+
+ return _collapse_addresses_internal(addrs + nets)
+
+
+def get_mixed_type_key(obj):
+ """Return a key suitable for sorting between networks and addresses.
+
+ Address and Network objects are not sortable by default; they're
+ fundamentally different so the expression
+
+ IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24')
+
+ doesn't make any sense. There are some times however, where you may wish
+ to have ipaddress sort these for you anyway. If you need to do this, you
+ can use this function as the key= argument to sorted().
+
+ Args:
+ obj: either a Network or Address object.
+ Returns:
+ appropriate key.
+
+ """
+ if isinstance(obj, _BaseNetwork):
+ return obj._get_networks_key()
+ elif isinstance(obj, _BaseAddress):
+ return obj._get_address_key()
+ return NotImplemented
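+    # Illustrative: sorted([IPv4Network(u"192.0.2.0/24"),
+    # IPv4Address(u"192.0.2.1")], key=get_mixed_type_key) sorts without
+    # raising TypeError.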
+
+
+class _IPAddressBase(_TotalOrderingMixin):
+
+ """The mother class."""
+
+ __slots__ = ()
+
+ @property
+ def exploded(self):
+ """Return the longhand version of the IP address as a string."""
+ return self._explode_shorthand_ip_string()
+
+ @property
+ def compressed(self):
+ """Return the shorthand version of the IP address as a string."""
+ return _compat_str(self)
+
+ @property
+ def reverse_pointer(self):
+ """The name of the reverse DNS pointer for the IP address, e.g.:
+ >>> ipaddress.ip_address("127.0.0.1").reverse_pointer
+ '1.0.0.127.in-addr.arpa'
+ >>> ipaddress.ip_address("2001:db8::1").reverse_pointer
+ '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa'
+
+ """
+ return self._reverse_pointer()
+
+ @property
+ def version(self):
+ msg = "%200s has no version specified" % (type(self),)
+ raise NotImplementedError(msg)
+
+ def _check_int_address(self, address):
+ if address < 0:
+ msg = "%d (< 0) is not permitted as an IPv%d address"
+ raise AddressValueError(msg % (address, self._version))
+ if address > self._ALL_ONES:
+ msg = "%d (>= 2**%d) is not permitted as an IPv%d address"
+ raise AddressValueError(
+ msg % (address, self._max_prefixlen, self._version)
+ )
+
+ def _check_packed_address(self, address, expected_len):
+ address_len = len(address)
+ if address_len != expected_len:
+ msg = (
+ "%r (len %d != %d) is not permitted as an IPv%d address. "
+ "Did you pass in a bytes (str in Python 2) instead of"
+ " a unicode object?"
+ )
+ raise AddressValueError(
+ msg % (address, address_len, expected_len, self._version)
+ )
+
+ @classmethod
+ def _ip_int_from_prefix(cls, prefixlen):
+ """Turn the prefix length into a bitwise netmask
+
+ Args:
+ prefixlen: An integer, the prefix length.
+
+ Returns:
+ An integer.
+
+ """
+ return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen)
+
+ @classmethod
+ def _prefix_from_ip_int(cls, ip_int):
+ """Return prefix length from the bitwise netmask.
+
+ Args:
+ ip_int: An integer, the netmask in expanded bitwise format
+
+ Returns:
+ An integer, the prefix length.
+
+ Raises:
+ ValueError: If the input intermingles zeroes & ones
+ """
+ trailing_zeroes = _count_righthand_zero_bits(
+ ip_int, cls._max_prefixlen
+ )
+ prefixlen = cls._max_prefixlen - trailing_zeroes
+ leading_ones = ip_int >> trailing_zeroes
+ all_ones = (1 << prefixlen) - 1
+ if leading_ones != all_ones:
+ byteslen = cls._max_prefixlen // 8
+ details = _compat_to_bytes(ip_int, byteslen, "big")
+ msg = "Netmask pattern %r mixes zeroes & ones"
+ raise ValueError(msg % details)
+ return prefixlen
+
+ @classmethod
+ def _report_invalid_netmask(cls, netmask_str):
+ msg = "%r is not a valid netmask" % netmask_str
+ raise NetmaskValueError(msg)
+
+ @classmethod
+ def _prefix_from_prefix_string(cls, prefixlen_str):
+ """Return prefix length from a numeric string
+
+ Args:
+ prefixlen_str: The string to be converted
+
+ Returns:
+ An integer, the prefix length.
+
+ Raises:
+ NetmaskValueError: If the input is not a valid netmask
+ """
+ # int allows a leading +/- as well as surrounding whitespace,
+ # so we ensure that isn't the case
+ if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str):
+ cls._report_invalid_netmask(prefixlen_str)
+ try:
+ prefixlen = int(prefixlen_str)
+ except ValueError:
+ cls._report_invalid_netmask(prefixlen_str)
+ if not (0 <= prefixlen <= cls._max_prefixlen):
+ cls._report_invalid_netmask(prefixlen_str)
+ return prefixlen
+
+ @classmethod
+ def _prefix_from_ip_string(cls, ip_str):
+ """Turn a netmask/hostmask string into a prefix length
+
+ Args:
+ ip_str: The netmask/hostmask to be converted
+
+ Returns:
+ An integer, the prefix length.
+
+ Raises:
+ NetmaskValueError: If the input is not a valid netmask/hostmask
+ """
+ # Parse the netmask/hostmask like an IP address.
+ try:
+ ip_int = cls._ip_int_from_string(ip_str)
+ except AddressValueError:
+ cls._report_invalid_netmask(ip_str)
+
+ # Try matching a netmask (this would be /1*0*/ as a bitwise regexp).
+ # Note that the two ambiguous cases (all-ones and all-zeroes) are
+ # treated as netmasks.
+ try:
+ return cls._prefix_from_ip_int(ip_int)
+ except ValueError:
+ pass
+
+ # Invert the bits, and try matching a /0+1+/ hostmask instead.
+ ip_int ^= cls._ALL_ONES
+ try:
+ return cls._prefix_from_ip_int(ip_int)
+ except ValueError:
+ cls._report_invalid_netmask(ip_str)
+
+ def __reduce__(self):
+ return self.__class__, (_compat_str(self),)
+
+
+class _BaseAddress(_IPAddressBase):
+
+ """A generic IP object.
+
+ This IP class contains the version independent methods which are
+ used by single IP addresses.
+ """
+
+ __slots__ = ()
+
+ def __int__(self):
+ return self._ip
+
+ def __eq__(self, other):
+ try:
+ return self._ip == other._ip and self._version == other._version
+ except AttributeError:
+ return NotImplemented
+
+ def __lt__(self, other):
+ if not isinstance(other, _IPAddressBase):
+ return NotImplemented
+ if not isinstance(other, _BaseAddress):
+ raise TypeError(
+ "%s and %s are not of the same type" % (self, other)
+ )
+ if self._version != other._version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (self, other)
+ )
+ if self._ip != other._ip:
+ return self._ip < other._ip
+ return False
+
+    # Shorthand for integer addition and subtraction. This is not
+    # meant to ever support addition/subtraction of addresses.
+ def __add__(self, other):
+ if not isinstance(other, _compat_int_types):
+ return NotImplemented
+ return self.__class__(int(self) + other)
+
+ def __sub__(self, other):
+ if not isinstance(other, _compat_int_types):
+ return NotImplemented
+ return self.__class__(int(self) - other)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, _compat_str(self))
+
+ def __str__(self):
+ return _compat_str(self._string_from_ip_int(self._ip))
+
+ def __hash__(self):
+ return hash(hex(int(self._ip)))
+
+ def _get_address_key(self):
+ return (self._version, self)
+
+ def __reduce__(self):
+ return self.__class__, (self._ip,)
+
+
+class _BaseNetwork(_IPAddressBase):
+
+ """A generic IP network object.
+
+ This IP class contains the version independent methods which are
+ used by networks.
+
+ """
+
+ def __init__(self, address):
+ self._cache = {}
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, _compat_str(self))
+
+ def __str__(self):
+ return "%s/%d" % (self.network_address, self.prefixlen)
+
+ def hosts(self):
+ """Generate Iterator over usable hosts in a network.
+
+ This is like __iter__ except it doesn't return the network
+ or broadcast addresses.
+
+ """
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in _compat_range(network + 1, broadcast):
+ yield self._address_class(x)
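+        # Illustrative: list(IPv4Network(u"192.0.2.0/30").hosts()) yields
+        # [IPv4Address('192.0.2.1'), IPv4Address('192.0.2.2')]; the network
+        # and broadcast addresses are skipped.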
+
+ def __iter__(self):
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in _compat_range(network, broadcast + 1):
+ yield self._address_class(x)
+
+ def __getitem__(self, n):
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ if n >= 0:
+ if network + n > broadcast:
+ raise IndexError("address out of range")
+ return self._address_class(network + n)
+ else:
+ n += 1
+ if broadcast + n < network:
+ raise IndexError("address out of range")
+ return self._address_class(broadcast + n)
+
+ def __lt__(self, other):
+ if not isinstance(other, _IPAddressBase):
+ return NotImplemented
+ if not isinstance(other, _BaseNetwork):
+ raise TypeError(
+ "%s and %s are not of the same type" % (self, other)
+ )
+ if self._version != other._version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (self, other)
+ )
+ if self.network_address != other.network_address:
+ return self.network_address < other.network_address
+ if self.netmask != other.netmask:
+ return self.netmask < other.netmask
+ return False
+
+ def __eq__(self, other):
+ try:
+ return (
+ self._version == other._version
+ and self.network_address == other.network_address
+ and int(self.netmask) == int(other.netmask)
+ )
+ except AttributeError:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(int(self.network_address) ^ int(self.netmask))
+
+ def __contains__(self, other):
+ # always false if one is v4 and the other is v6.
+ if self._version != other._version:
+ return False
+ # dealing with another network.
+ if isinstance(other, _BaseNetwork):
+ return False
+ # dealing with another address
+ else:
+ # address
+ return (
+ int(self.network_address)
+ <= int(other._ip)
+ <= int(self.broadcast_address)
+ )
+
+ def overlaps(self, other):
+ """Tell if self is partly contained in other."""
+ return self.network_address in other or (
+ self.broadcast_address in other
+ or (
+ other.network_address in self
+ or (other.broadcast_address in self)
+ )
+ )
+
+ @property
+ def broadcast_address(self):
+ x = self._cache.get("broadcast_address")
+ if x is None:
+ x = self._address_class(
+ int(self.network_address) | int(self.hostmask)
+ )
+ self._cache["broadcast_address"] = x
+ return x
+
+ @property
+ def hostmask(self):
+ x = self._cache.get("hostmask")
+ if x is None:
+ x = self._address_class(int(self.netmask) ^ self._ALL_ONES)
+ self._cache["hostmask"] = x
+ return x
+
+ @property
+ def with_prefixlen(self):
+ return "%s/%d" % (self.network_address, self._prefixlen)
+
+ @property
+ def with_netmask(self):
+ return "%s/%s" % (self.network_address, self.netmask)
+
+ @property
+ def with_hostmask(self):
+ return "%s/%s" % (self.network_address, self.hostmask)
+
+ @property
+ def num_addresses(self):
+ """Number of hosts in the current subnet."""
+ return int(self.broadcast_address) - int(self.network_address) + 1
+
+ @property
+ def _address_class(self):
+ # Returning bare address objects (rather than interfaces) allows for
+ # more consistent behaviour across the network address, broadcast
+ # address and individual host addresses.
+ msg = "%200s has no associated address class" % (type(self),)
+ raise NotImplementedError(msg)
+
+ @property
+ def prefixlen(self):
+ return self._prefixlen
+
+ def address_exclude(self, other):
+ """Remove an address from a larger block.
+
+ For example:
+
+ addr1 = ip_network('192.0.2.0/28')
+ addr2 = ip_network('192.0.2.1/32')
+ list(addr1.address_exclude(addr2)) =
+ [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'),
+ IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')]
+
+ or IPv6:
+
+ addr1 = ip_network('2001:db8::1/32')
+ addr2 = ip_network('2001:db8::1/128')
+ list(addr1.address_exclude(addr2)) =
+ [ip_network('2001:db8::1/128'),
+ ip_network('2001:db8::2/127'),
+ ip_network('2001:db8::4/126'),
+ ip_network('2001:db8::8/125'),
+ ...
+ ip_network('2001:db8:8000::/33')]
+
+ Args:
+ other: An IPv4Network or IPv6Network object of the same type.
+
+ Returns:
+ An iterator of the IPv(4|6)Network objects which is self
+ minus other.
+
+ Raises:
+ TypeError: If self and other are of differing address
+ versions, or if other is not a network object.
+ ValueError: If other is not completely contained by self.
+
+ """
+ if not self._version == other._version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (self, other)
+ )
+
+ if not isinstance(other, _BaseNetwork):
+ raise TypeError("%s is not a network object" % other)
+
+ if not other.subnet_of(self):
+ raise ValueError("%s not contained in %s" % (other, self))
+ if other == self:
+ return
+
+ # Make sure we're comparing the network of other.
+ other = other.__class__(
+ "%s/%s" % (other.network_address, other.prefixlen)
+ )
+
+ s1, s2 = self.subnets()
+ while s1 != other and s2 != other:
+ if other.subnet_of(s1):
+ yield s2
+ s1, s2 = s1.subnets()
+ elif other.subnet_of(s2):
+ yield s1
+ s1, s2 = s2.subnets()
+ else:
+ # If we got here, there's a bug somewhere.
+ raise AssertionError(
+ "Error performing exclusion: "
+ "s1: %s s2: %s other: %s" % (s1, s2, other)
+ )
+ if s1 == other:
+ yield s2
+ elif s2 == other:
+ yield s1
+ else:
+ # If we got here, there's a bug somewhere.
+ raise AssertionError(
+ "Error performing exclusion: "
+ "s1: %s s2: %s other: %s" % (s1, s2, other)
+ )
+
+ def compare_networks(self, other):
+ """Compare two IP objects.
+
+ This is only concerned about the comparison of the integer
+ representation of the network addresses. This means that the
+ host bits aren't considered at all in this method. If you want
+ to compare host bits, you can easily enough do a
+ 'HostA._ip < HostB._ip'
+
+ Args:
+ other: An IP object.
+
+ Returns:
+ If the IP versions of self and other are the same, returns:
+
+ -1 if self < other:
+ eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25')
+ IPv6Network('2001:db8::1000/124') <
+ IPv6Network('2001:db8::2000/124')
+ 0 if self == other
+ eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24')
+ IPv6Network('2001:db8::1000/124') ==
+ IPv6Network('2001:db8::1000/124')
+ 1 if self > other
+ eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25')
+ IPv6Network('2001:db8::2000/124') >
+ IPv6Network('2001:db8::1000/124')
+
+ Raises:
+ TypeError if the IP versions are different.
+
+ """
+ # does this need to raise a ValueError?
+ if self._version != other._version:
+ raise TypeError(
+ "%s and %s are not of the same type" % (self, other)
+ )
+ # self._version == other._version below here:
+ if self.network_address < other.network_address:
+ return -1
+ if self.network_address > other.network_address:
+ return 1
+ # self.network_address == other.network_address below here:
+ if self.netmask < other.netmask:
+ return -1
+ if self.netmask > other.netmask:
+ return 1
+ return 0
+
+ def _get_networks_key(self):
+ """Network-only key function.
+
+ Returns an object that identifies this address' network and
+ netmask. This function is a suitable "key" argument for sorted()
+ and list.sort().
+
+ """
+ return (self._version, self.network_address, self.netmask)
+
+ def subnets(self, prefixlen_diff=1, new_prefix=None):
+ """The subnets which join to make the current subnet.
+
+ In the case that self contains only one IP
+ (self._prefixlen == 32 for IPv4 or self._prefixlen == 128
+ for IPv6), yield an iterator with just ourself.
+
+ Args:
+ prefixlen_diff: An integer, the amount the prefix length
+ should be increased by. This should not be set if
+ new_prefix is also set.
+ new_prefix: The desired new prefix length. This must be a
+ larger number (smaller prefix) than the existing prefix.
+ This should not be set if prefixlen_diff is also set.
+
+ Returns:
+ An iterator of IPv(4|6) objects.
+
+ Raises:
+ ValueError: The prefixlen_diff is too small or too large.
+ OR
+ prefixlen_diff and new_prefix are both set or new_prefix
+ is a smaller number than the current prefix (smaller
+ number means a larger network)
+
+ """
+ if self._prefixlen == self._max_prefixlen:
+ yield self
+ return
+
+ if new_prefix is not None:
+ if new_prefix < self._prefixlen:
+ raise ValueError("new prefix must be longer")
+ if prefixlen_diff != 1:
+ raise ValueError("cannot set prefixlen_diff and new_prefix")
+ prefixlen_diff = new_prefix - self._prefixlen
+
+ if prefixlen_diff < 0:
+ raise ValueError("prefix length diff must be > 0")
+ new_prefixlen = self._prefixlen + prefixlen_diff
+
+ if new_prefixlen > self._max_prefixlen:
+ raise ValueError(
+ "prefix length diff %d is invalid for netblock %s"
+ % (new_prefixlen, self)
+ )
+
+ start = int(self.network_address)
+ end = int(self.broadcast_address) + 1
+ step = (int(self.hostmask) + 1) >> prefixlen_diff
+ for new_addr in _compat_range(start, end, step):
+ current = self.__class__((new_addr, new_prefixlen))
+ yield current
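+        # Illustrative: list(IPv4Network(u"192.0.2.0/24").subnets(new_prefix=26))
+        # yields the four /26 networks 192.0.2.0/26 through 192.0.2.192/26.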
+
+ def supernet(self, prefixlen_diff=1, new_prefix=None):
+ """The supernet containing the current network.
+
+ Args:
+ prefixlen_diff: An integer, the amount the prefix length of
+ the network should be decreased by. For example, given a
+ /24 network and a prefixlen_diff of 3, a supernet with a
+ /21 netmask is returned.
+
+ Returns:
+ An IPv4 network object.
+
+ Raises:
+ ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have
+ a negative prefix length.
+ OR
+ If prefixlen_diff and new_prefix are both set or new_prefix is a
+ larger number than the current prefix (larger number means a
+ smaller network)
+
+ """
+ if self._prefixlen == 0:
+ return self
+
+ if new_prefix is not None:
+ if new_prefix > self._prefixlen:
+ raise ValueError("new prefix must be shorter")
+ if prefixlen_diff != 1:
+ raise ValueError("cannot set prefixlen_diff and new_prefix")
+ prefixlen_diff = self._prefixlen - new_prefix
+
+ new_prefixlen = self.prefixlen - prefixlen_diff
+ if new_prefixlen < 0:
+ raise ValueError(
+ "current prefixlen is %d, cannot have a prefixlen_diff of %d"
+ % (self.prefixlen, prefixlen_diff)
+ )
+ return self.__class__(
+ (
+ int(self.network_address)
+ & (int(self.netmask) << prefixlen_diff),
+ new_prefixlen,
+ )
+ )
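+        # Illustrative: IPv4Network(u"192.0.2.0/24").supernet(prefixlen_diff=3)
+        # returns IPv4Network('192.0.0.0/21').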
+
+ @property
+ def is_multicast(self):
+ """Test if the address is reserved for multicast use.
+
+ Returns:
+ A boolean, True if the address is a multicast address.
+ See RFC 2373 2.7 for details.
+
+ """
+ return (
+ self.network_address.is_multicast
+ and self.broadcast_address.is_multicast
+ )
+
+ @staticmethod
+ def _is_subnet_of(a, b):
+ try:
+ # Always false if one is v4 and the other is v6.
+ if a._version != b._version:
+ raise TypeError(
+ "%s and %s are not of the same version" % (a, b)
+ )
+ return (
+ b.network_address <= a.network_address
+ and b.broadcast_address >= a.broadcast_address
+ )
+ except AttributeError:
+ raise TypeError(
+ "Unable to test subnet containment "
+ "between %s and %s" % (a, b)
+ )
+
+ def subnet_of(self, other):
+ """Return True if this network is a subnet of other."""
+ return self._is_subnet_of(self, other)
+
+ def supernet_of(self, other):
+ """Return True if this network is a supernet of other."""
+ return self._is_subnet_of(other, self)
+
+ @property
+ def is_reserved(self):
+ """Test if the address is otherwise IETF reserved.
+
+ Returns:
+ A boolean, True if the address is within one of the
+ reserved IPv6 Network ranges.
+
+ """
+ return (
+ self.network_address.is_reserved
+ and self.broadcast_address.is_reserved
+ )
+
+ @property
+ def is_link_local(self):
+ """Test if the address is reserved for link-local.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 4291.
+
+ """
+ return (
+ self.network_address.is_link_local
+ and self.broadcast_address.is_link_local
+ )
+
+ @property
+ def is_private(self):
+ """Test if this address is allocated for private networks.
+
+ Returns:
+ A boolean, True if the address is reserved per
+ iana-ipv4-special-registry or iana-ipv6-special-registry.
+
+ """
+ return (
+ self.network_address.is_private
+ and self.broadcast_address.is_private
+ )
+
+ @property
+ def is_global(self):
+ """Test if this address is allocated for public networks.
+
+ Returns:
+ A boolean, True if the address is not reserved per
+ iana-ipv4-special-registry or iana-ipv6-special-registry.
+
+ """
+ return not self.is_private
+
+ @property
+ def is_unspecified(self):
+ """Test if the address is unspecified.
+
+ Returns:
+ A boolean, True if this is the unspecified address as defined in
+ RFC 2373 2.5.2.
+
+ """
+ return (
+ self.network_address.is_unspecified
+ and self.broadcast_address.is_unspecified
+ )
+
+ @property
+ def is_loopback(self):
+ """Test if the address is a loopback address.
+
+ Returns:
+ A boolean, True if the address is a loopback address as defined in
+ RFC 2373 2.5.3.
+
+ """
+ return (
+ self.network_address.is_loopback
+ and self.broadcast_address.is_loopback
+ )
+
+
+class _BaseV4(object):
+
+ """Base IPv4 object.
+
+ The following methods are used by IPv4 objects in both single IP
+ addresses and networks.
+
+ """
+
+ __slots__ = ()
+ _version = 4
+ # Equivalent to 255.255.255.255 or 32 bits of 1's.
+ _ALL_ONES = (2 ** IPV4LENGTH) - 1
+ _DECIMAL_DIGITS = frozenset("0123456789")
+
+ # the valid octets for host and netmasks. only useful for IPv4.
+ _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0])
+
+ _max_prefixlen = IPV4LENGTH
+ # There are only a handful of valid v4 netmasks, so we cache them all
+ # when constructed (see _make_netmask()).
+ _netmask_cache = {}
+
+ def _explode_shorthand_ip_string(self):
+ return _compat_str(self)
+
+ @classmethod
+ def _make_netmask(cls, arg):
+ """Make a (netmask, prefix_len) tuple from the given argument.
+
+ Argument can be:
+ - an integer (the prefix length)
+ - a string representing the prefix length (e.g. "24")
+ - a string representing the prefix netmask (e.g. "255.255.255.0")
+ """
+ if arg not in cls._netmask_cache:
+ if isinstance(arg, _compat_int_types):
+ prefixlen = arg
+ else:
+ try:
+ # Check for a netmask in prefix length form
+ prefixlen = cls._prefix_from_prefix_string(arg)
+ except NetmaskValueError:
+ # Check for a netmask or hostmask in dotted-quad form.
+ # This may raise NetmaskValueError.
+ prefixlen = cls._prefix_from_ip_string(arg)
+ netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen))
+ cls._netmask_cache[arg] = netmask, prefixlen
+ return cls._netmask_cache[arg]
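+        # Illustrative: _make_netmask(24), _make_netmask("24") and
+        # _make_netmask("255.255.255.0") all return
+        # (IPv4Address('255.255.255.0'), 24); results are memoized per class.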
+
+ @classmethod
+ def _ip_int_from_string(cls, ip_str):
+ """Turn the given IP string into an integer for comparison.
+
+ Args:
+ ip_str: A string, the IP ip_str.
+
+ Returns:
+ The IP ip_str as an integer.
+
+ Raises:
+ AddressValueError: if ip_str isn't a valid IPv4 Address.
+
+ """
+ if not ip_str:
+ raise AddressValueError("Address cannot be empty")
+
+ octets = ip_str.split(".")
+ if len(octets) != 4:
+ raise AddressValueError("Expected 4 octets in %r" % ip_str)
+
+ try:
+ return _compat_int_from_byte_vals(
+ map(cls._parse_octet, octets), "big"
+ )
+ except ValueError as exc:
+ raise AddressValueError("%s in %r" % (exc, ip_str))
+
+ @classmethod
+ def _parse_octet(cls, octet_str):
+ """Convert a decimal octet into an integer.
+
+ Args:
+ octet_str: A string, the number to parse.
+
+ Returns:
+ The octet as an integer.
+
+ Raises:
+ ValueError: if the octet isn't strictly a decimal from [0..255].
+
+ """
+ if not octet_str:
+ raise ValueError("Empty octet not permitted")
+ # Whitelist the characters, since int() allows a lot of bizarre stuff.
+ if not cls._DECIMAL_DIGITS.issuperset(octet_str):
+ msg = "Only decimal digits permitted in %r"
+ raise ValueError(msg % octet_str)
+ # We do the length check second, since the invalid character error
+ # is likely to be more informative for the user
+ if len(octet_str) > 3:
+ msg = "At most 3 characters permitted in %r"
+ raise ValueError(msg % octet_str)
+ # Convert to integer (we know digits are legal)
+ octet_int = int(octet_str, 10)
+ # Any octets that look like they *might* be written in octal,
+ # and which don't look exactly the same in both octal and
+ # decimal are rejected as ambiguous
+ if octet_int > 7 and octet_str[0] == "0":
+ msg = "Ambiguous (octal/decimal) value in %r not permitted"
+ raise ValueError(msg % octet_str)
+ if octet_int > 255:
+ raise ValueError("Octet %d (> 255) not permitted" % octet_int)
+ return octet_int
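+        # Illustrative: _parse_octet("127") == 127, while "017" is rejected
+        # as ambiguous (octal vs. decimal) and "256" exceeds 255.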
+
+ @classmethod
+ def _string_from_ip_int(cls, ip_int):
+ """Turns a 32-bit integer into dotted decimal notation.
+
+ Args:
+ ip_int: An integer, the IP address.
+
+ Returns:
+ The IP address as a string in dotted decimal notation.
+
+ """
+ return ".".join(
+ _compat_str(
+ struct.unpack(b"!B", b)[0] if isinstance(b, bytes) else b
+ )
+ for b in _compat_to_bytes(ip_int, 4, "big")
+ )
+
+ def _is_hostmask(self, ip_str):
+ """Test if the IP string is a hostmask (rather than a netmask).
+
+ Args:
+ ip_str: A string, the potential hostmask.
+
+ Returns:
+ A boolean, True if the IP string is a hostmask.
+
+ """
+ bits = ip_str.split(".")
+ try:
+ parts = [x for x in map(int, bits) if x in self._valid_mask_octets]
+ except ValueError:
+ return False
+ if len(parts) != len(bits):
+ return False
+ if parts[0] < parts[-1]:
+ return True
+ return False
+
+ def _reverse_pointer(self):
+ """Return the reverse DNS pointer name for the IPv4 address.
+
+ This implements the method described in RFC1035 3.5.
+
+ """
+ reverse_octets = _compat_str(self).split(".")[::-1]
+ return ".".join(reverse_octets) + ".in-addr.arpa"
+
+ @property
+ def max_prefixlen(self):
+ return self._max_prefixlen
+
+ @property
+ def version(self):
+ return self._version
+
+
+class IPv4Address(_BaseV4, _BaseAddress):
+
+ """Represent and manipulate single IPv4 Addresses."""
+
+ __slots__ = ("_ip", "__weakref__")
+
+ def __init__(self, address):
+
+ """
+ Args:
+ address: A string or integer representing the IP
+
+ Additionally, an integer can be passed, so
+ IPv4Address('192.0.2.1') == IPv4Address(3221225985).
+ or, more generally
+ IPv4Address(int(IPv4Address('192.0.2.1'))) ==
+ IPv4Address('192.0.2.1')
+
+ Raises:
+ AddressValueError: If ipaddress isn't a valid IPv4 address.
+
+ """
+ # Efficient constructor from integer.
+ if isinstance(address, _compat_int_types):
+ self._check_int_address(address)
+ self._ip = address
+ return
+
+ # Constructing from a packed address
+ if isinstance(address, bytes):
+ self._check_packed_address(address, 4)
+ bvs = _compat_bytes_to_byte_vals(address)
+ self._ip = _compat_int_from_byte_vals(bvs, "big")
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP string.
+ addr_str = _compat_str(address)
+ if "/" in addr_str:
+ raise AddressValueError("Unexpected '/' in %r" % address)
+ self._ip = self._ip_int_from_string(addr_str)
+
+ @property
+ def packed(self):
+ """The binary representation of this address."""
+ return v4_int_to_packed(self._ip)
+
+ @property
+ def is_reserved(self):
+ """Test if the address is otherwise IETF reserved.
+
+ Returns:
+ A boolean, True if the address is within the
+ reserved IPv4 Network range.
+
+ """
+ return self in self._constants._reserved_network
+
+ @property
+ def is_private(self):
+ """Test if this address is allocated for private networks.
+
+ Returns:
+ A boolean, True if the address is reserved per
+ iana-ipv4-special-registry.
+
+ """
+ return any(self in net for net in self._constants._private_networks)
+
+ @property
+ def is_global(self):
+ return (
+ self not in self._constants._public_network and not self.is_private
+ )
+
+ @property
+ def is_multicast(self):
+ """Test if the address is reserved for multicast use.
+
+ Returns:
+ A boolean, True if the address is multicast.
+ See RFC 3171 for details.
+
+ """
+ return self in self._constants._multicast_network
+
+ @property
+ def is_unspecified(self):
+ """Test if the address is unspecified.
+
+ Returns:
+ A boolean, True if this is the unspecified address as defined in
+ RFC 5735 3.
+
+ """
+ return self == self._constants._unspecified_address
+
+ @property
+ def is_loopback(self):
+ """Test if the address is a loopback address.
+
+ Returns:
+ A boolean, True if the address is a loopback per RFC 3330.
+
+ """
+ return self in self._constants._loopback_network
+
+ @property
+ def is_link_local(self):
+ """Test if the address is reserved for link-local.
+
+ Returns:
+ A boolean, True if the address is link-local per RFC 3927.
+
+ """
+ return self in self._constants._linklocal_network
+
+
+class IPv4Interface(IPv4Address):
+ def __init__(self, address):
+ if isinstance(address, (bytes, _compat_int_types)):
+ IPv4Address.__init__(self, address)
+ self.network = IPv4Network(self._ip)
+ self._prefixlen = self._max_prefixlen
+ return
+
+ if isinstance(address, tuple):
+ IPv4Address.__init__(self, address[0])
+ if len(address) > 1:
+ self._prefixlen = int(address[1])
+ else:
+ self._prefixlen = self._max_prefixlen
+
+ self.network = IPv4Network(address, strict=False)
+ self.netmask = self.network.netmask
+ self.hostmask = self.network.hostmask
+ return
+
+ addr = _split_optional_netmask(address)
+ IPv4Address.__init__(self, addr[0])
+
+ self.network = IPv4Network(address, strict=False)
+ self._prefixlen = self.network._prefixlen
+
+ self.netmask = self.network.netmask
+ self.hostmask = self.network.hostmask
+
+ def __str__(self):
+ return "%s/%d" % (
+ self._string_from_ip_int(self._ip),
+ self.network.prefixlen,
+ )
+
+ def __eq__(self, other):
+ address_equal = IPv4Address.__eq__(self, other)
+ if not address_equal or address_equal is NotImplemented:
+ return address_equal
+ try:
+ return self.network == other.network
+ except AttributeError:
+ # An interface with an associated network is NOT the
+ # same as an unassociated address. That's why the hash
+ # takes the extra info into account.
+ return False
+
+ def __lt__(self, other):
+ address_less = IPv4Address.__lt__(self, other)
+ if address_less is NotImplemented:
+ return NotImplemented
+ try:
+ return (
+ self.network < other.network
+ or self.network == other.network
+ and address_less
+ )
+ except AttributeError:
+ # We *do* allow addresses and interfaces to be sorted. The
+ # unassociated address is considered less than all interfaces.
+ return False
+
+ def __hash__(self):
+ return self._ip ^ self._prefixlen ^ int(self.network.network_address)
+
+ __reduce__ = _IPAddressBase.__reduce__
+
+ @property
+ def ip(self):
+ return IPv4Address(self._ip)
+
+ @property
+ def with_prefixlen(self):
+ return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen)
+
+ @property
+ def with_netmask(self):
+ return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask)
+
+ @property
+ def with_hostmask(self):
+ return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask)
+
+
+class IPv4Network(_BaseV4, _BaseNetwork):
+
+ """This class represents and manipulates 32-bit IPv4 network + addresses..
+
+ Attributes: [examples for IPv4Network('192.0.2.0/27')]
+ .network_address: IPv4Address('192.0.2.0')
+ .hostmask: IPv4Address('0.0.0.31')
+        .broadcast_address: IPv4Address('192.0.2.31')
+ .netmask: IPv4Address('255.255.255.224')
+ .prefixlen: 27
+
+ """
+
+ # Class to use when creating address objects
+ _address_class = IPv4Address
+
+ def __init__(self, address, strict=True):
+
+ """Instantiate a new IPv4 network object.
+
+ Args:
+ address: A string or integer representing the IP [& network].
+ '192.0.2.0/24'
+ '192.0.2.0/255.255.255.0'
+ '192.0.0.2/0.0.0.255'
+ are all functionally the same in IPv4. Similarly,
+ '192.0.2.1'
+ '192.0.2.1/255.255.255.255'
+ '192.0.2.1/32'
+ are also functionally equivalent. That is to say, failing to
+ provide a subnetmask will create an object with a mask of /32.
+
+ If the mask (portion after the / in the argument) is given in
+ dotted quad form, it is treated as a netmask if it starts with a
+ non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it
+ starts with a zero field (e.g. 0.255.255.255 == /8), with the
+ single exception of an all-zero mask which is treated as a
+ netmask == /0. If no mask is given, a default of /32 is used.
+
+ Additionally, an integer can be passed, so
+ IPv4Network('192.0.2.1') == IPv4Network(3221225985)
+ or, more generally
+ IPv4Interface(int(IPv4Interface('192.0.2.1'))) ==
+ IPv4Interface('192.0.2.1')
+
+ Raises:
+ AddressValueError: If ipaddress isn't a valid IPv4 address.
+ NetmaskValueError: If the netmask isn't valid for
+ an IPv4 address.
+ ValueError: If strict is True and a network address is not
+ supplied.
+
+ """
+ _BaseNetwork.__init__(self, address)
+
+ # Constructing from a packed address or integer
+ if isinstance(address, (_compat_int_types, bytes)):
+ self.network_address = IPv4Address(address)
+ self.netmask, self._prefixlen = self._make_netmask(
+ self._max_prefixlen
+ )
+ # fixme: address/network test here.
+ return
+
+ if isinstance(address, tuple):
+ if len(address) > 1:
+ arg = address[1]
+ else:
+ # We weren't given an address[1]
+ arg = self._max_prefixlen
+ self.network_address = IPv4Address(address[0])
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+ packed = int(self.network_address)
+ if packed & int(self.netmask) != packed:
+ if strict:
+ raise ValueError("%s has host bits set" % self)
+ else:
+ self.network_address = IPv4Address(
+ packed & int(self.netmask)
+ )
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP prefix string.
+ addr = _split_optional_netmask(address)
+ self.network_address = IPv4Address(self._ip_int_from_string(addr[0]))
+
+ if len(addr) == 2:
+ arg = addr[1]
+ else:
+ arg = self._max_prefixlen
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+
+ if strict:
+ if (
+ IPv4Address(int(self.network_address) & int(self.netmask))
+ != self.network_address
+ ):
+ raise ValueError("%s has host bits set" % self)
+ self.network_address = IPv4Address(
+ int(self.network_address) & int(self.netmask)
+ )
+
+ if self._prefixlen == (self._max_prefixlen - 1):
+ self.hosts = self.__iter__
+
+ @property
+ def is_global(self):
+ """Test if this address is allocated for public networks.
+
+ Returns:
+ A boolean, True if the address is not reserved per
+ iana-ipv4-special-registry.
+
+ """
+ return (
+ not (
+ self.network_address in IPv4Network("100.64.0.0/10")
+ and self.broadcast_address in IPv4Network("100.64.0.0/10")
+ )
+ and not self.is_private
+ )
+
+
+class _IPv4Constants(object):
+
+ _linklocal_network = IPv4Network("169.254.0.0/16")
+
+ _loopback_network = IPv4Network("127.0.0.0/8")
+
+ _multicast_network = IPv4Network("224.0.0.0/4")
+
+ _public_network = IPv4Network("100.64.0.0/10")
+
+ _private_networks = [
+ IPv4Network("0.0.0.0/8"),
+ IPv4Network("10.0.0.0/8"),
+ IPv4Network("127.0.0.0/8"),
+ IPv4Network("169.254.0.0/16"),
+ IPv4Network("172.16.0.0/12"),
+ IPv4Network("192.0.0.0/29"),
+ IPv4Network("192.0.0.170/31"),
+ IPv4Network("192.0.2.0/24"),
+ IPv4Network("192.168.0.0/16"),
+ IPv4Network("198.18.0.0/15"),
+ IPv4Network("198.51.100.0/24"),
+ IPv4Network("203.0.113.0/24"),
+ IPv4Network("240.0.0.0/4"),
+ IPv4Network("255.255.255.255/32"),
+ ]
+
+ _reserved_network = IPv4Network("240.0.0.0/4")
+
+ _unspecified_address = IPv4Address("0.0.0.0")
+
+
+IPv4Address._constants = _IPv4Constants
+
+
+class _BaseV6(object):
+
+ """Base IPv6 object.
+
+ The following methods are used by IPv6 objects in both single IP
+ addresses and networks.
+
+ """
+
+ __slots__ = ()
+ _version = 6
+ _ALL_ONES = (2 ** IPV6LENGTH) - 1
+ _HEXTET_COUNT = 8
+ _HEX_DIGITS = frozenset("0123456789ABCDEFabcdef")
+ _max_prefixlen = IPV6LENGTH
+
+ # There are only a bunch of valid v6 netmasks, so we cache them all
+ # when constructed (see _make_netmask()).
+ _netmask_cache = {}
+
+ @classmethod
+ def _make_netmask(cls, arg):
+ """Make a (netmask, prefix_len) tuple from the given argument.
+
+ Argument can be:
+ - an integer (the prefix length)
+ - a string representing the prefix length (e.g. "24")
+ - a string representing the prefix netmask (e.g. "255.255.255.0")
+ """
+ if arg not in cls._netmask_cache:
+ if isinstance(arg, _compat_int_types):
+ prefixlen = arg
+ else:
+ prefixlen = cls._prefix_from_prefix_string(arg)
+ netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen))
+ cls._netmask_cache[arg] = netmask, prefixlen
+ return cls._netmask_cache[arg]
+
+ @classmethod
+ def _ip_int_from_string(cls, ip_str):
+ """Turn an IPv6 ip_str into an integer.
+
+ Args:
+ ip_str: A string, the IPv6 ip_str.
+
+ Returns:
+ An int, the IPv6 address
+
+ Raises:
+ AddressValueError: if ip_str isn't a valid IPv6 Address.
+
+ """
+ if not ip_str:
+ raise AddressValueError("Address cannot be empty")
+
+ parts = ip_str.split(":")
+
+ # An IPv6 address needs at least 2 colons (3 parts).
+ _min_parts = 3
+ if len(parts) < _min_parts:
+ msg = "At least %d parts expected in %r" % (_min_parts, ip_str)
+ raise AddressValueError(msg)
+
+ # If the address has an IPv4-style suffix, convert it to hexadecimal.
+ if "." in parts[-1]:
+ try:
+ ipv4_int = IPv4Address(parts.pop())._ip
+ except AddressValueError as exc:
+ raise AddressValueError("%s in %r" % (exc, ip_str))
+ parts.append("%x" % ((ipv4_int >> 16) & 0xFFFF))
+ parts.append("%x" % (ipv4_int & 0xFFFF))
+
+ # An IPv6 address can't have more than 8 colons (9 parts).
+ # The extra colon comes from using the "::" notation for a single
+ # leading or trailing zero part.
+ _max_parts = cls._HEXTET_COUNT + 1
+ if len(parts) > _max_parts:
+ msg = "At most %d colons permitted in %r" % (
+ _max_parts - 1,
+ ip_str,
+ )
+ raise AddressValueError(msg)
+
+ # Disregarding the endpoints, find '::' with nothing in between.
+ # This indicates that a run of zeroes has been skipped.
+ skip_index = None
+ for i in _compat_range(1, len(parts) - 1):
+ if not parts[i]:
+ if skip_index is not None:
+ # Can't have more than one '::'
+ msg = "At most one '::' permitted in %r" % ip_str
+ raise AddressValueError(msg)
+ skip_index = i
+
+ # parts_hi is the number of parts to copy from above/before the '::'
+ # parts_lo is the number of parts to copy from below/after the '::'
+ if skip_index is not None:
+ # If we found a '::', then check if it also covers the endpoints.
+ parts_hi = skip_index
+ parts_lo = len(parts) - skip_index - 1
+ if not parts[0]:
+ parts_hi -= 1
+ if parts_hi:
+ msg = "Leading ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # ^: requires ^::
+ if not parts[-1]:
+ parts_lo -= 1
+ if parts_lo:
+ msg = "Trailing ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # :$ requires ::$
+ parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo)
+ if parts_skipped < 1:
+ msg = "Expected at most %d other parts with '::' in %r"
+ raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str))
+ else:
+ # Otherwise, allocate the entire address to parts_hi. The
+ # endpoints could still be empty, but _parse_hextet() will check
+ # for that.
+ if len(parts) != cls._HEXTET_COUNT:
+ msg = "Exactly %d parts expected without '::' in %r"
+ raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str))
+ if not parts[0]:
+ msg = "Leading ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # ^: requires ^::
+ if not parts[-1]:
+ msg = "Trailing ':' only permitted as part of '::' in %r"
+ raise AddressValueError(msg % ip_str) # :$ requires ::$
+ parts_hi = len(parts)
+ parts_lo = 0
+ parts_skipped = 0
+
+ try:
+ # Now, parse the hextets into a 128-bit integer.
+ ip_int = 0
+ for i in range(parts_hi):
+ ip_int <<= 16
+ ip_int |= cls._parse_hextet(parts[i])
+ ip_int <<= 16 * parts_skipped
+ for i in range(-parts_lo, 0):
+ ip_int <<= 16
+ ip_int |= cls._parse_hextet(parts[i])
+ return ip_int
+ except ValueError as exc:
+ raise AddressValueError("%s in %r" % (exc, ip_str))
+
+ @classmethod
+ def _parse_hextet(cls, hextet_str):
+ """Convert an IPv6 hextet string into an integer.
+
+ Args:
+ hextet_str: A string, the number to parse.
+
+ Returns:
+ The hextet as an integer.
+
+ Raises:
+ ValueError: if the input isn't strictly a hex number from
+ [0..FFFF].
+
+ """
+ # Whitelist the characters, since int() allows a lot of bizarre stuff.
+ if not cls._HEX_DIGITS.issuperset(hextet_str):
+ raise ValueError("Only hex digits permitted in %r" % hextet_str)
+ # We do the length check second, since the invalid character error
+ # is likely to be more informative for the user
+ if len(hextet_str) > 4:
+ msg = "At most 4 characters permitted in %r"
+ raise ValueError(msg % hextet_str)
+ # Length check means we can skip checking the integer value
+ return int(hextet_str, 16)
+
+ @classmethod
+ def _compress_hextets(cls, hextets):
+ """Compresses a list of hextets.
+
+ Compresses a list of strings, replacing the longest continuous
+ sequence of "0" in the list with "" and adding empty strings at
+ the beginning or at the end of the string such that subsequently
+ calling ":".join(hextets) will produce the compressed version of
+ the IPv6 address.
+
+ Args:
+ hextets: A list of strings, the hextets to compress.
+
+ Returns:
+ A list of strings.
+
+ """
+ best_doublecolon_start = -1
+ best_doublecolon_len = 0
+ doublecolon_start = -1
+ doublecolon_len = 0
+ for index, hextet in enumerate(hextets):
+ if hextet == "0":
+ doublecolon_len += 1
+ if doublecolon_start == -1:
+ # Start of a sequence of zeros.
+ doublecolon_start = index
+ if doublecolon_len > best_doublecolon_len:
+ # This is the longest sequence of zeros so far.
+ best_doublecolon_len = doublecolon_len
+ best_doublecolon_start = doublecolon_start
+ else:
+ doublecolon_len = 0
+ doublecolon_start = -1
+
+ if best_doublecolon_len > 1:
+ best_doublecolon_end = (
+ best_doublecolon_start + best_doublecolon_len
+ )
+ # For zeros at the end of the address.
+ if best_doublecolon_end == len(hextets):
+ hextets += [""]
+ hextets[best_doublecolon_start:best_doublecolon_end] = [""]
+ # For zeros at the beginning of the address.
+ if best_doublecolon_start == 0:
+ hextets = [""] + hextets
+
+ return hextets
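+
+    # Illustrative sketch: for 2001:db8::1 the input hextets are
+    # ["2001", "db8", "0", "0", "0", "0", "0", "1"]; the longest zero run
+    # (five hextets starting at index 2) collapses to a single "", so
+    # ":".join(...) produces "2001:db8::1".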
+
+ @classmethod
+ def _string_from_ip_int(cls, ip_int=None):
+ """Turns a 128-bit integer into hexadecimal notation.
+
+ Args:
+ ip_int: An integer, the IP address.
+
+ Returns:
+ A string, the hexadecimal representation of the address.
+
+ Raises:
+ ValueError: The address is bigger than 128 bits of all ones.
+
+ """
+ if ip_int is None:
+ ip_int = int(cls._ip)
+
+ if ip_int > cls._ALL_ONES:
+ raise ValueError("IPv6 address is too large")
+
+ hex_str = "%032x" % ip_int
+ hextets = ["%x" % int(hex_str[x : x + 4], 16) for x in range(0, 32, 4)]
+
+ hextets = cls._compress_hextets(hextets)
+ return ":".join(hextets)
+
+ def _explode_shorthand_ip_string(self):
+ """Expand a shortened IPv6 address.
+
+ Returns:
+ A string, the expanded IPv6 address.
+
+ """
+ if isinstance(self, IPv6Network):
+ ip_str = _compat_str(self.network_address)
+ elif isinstance(self, IPv6Interface):
+ ip_str = _compat_str(self.ip)
+ else:
+ ip_str = _compat_str(self)
+
+ ip_int = self._ip_int_from_string(ip_str)
+ hex_str = "%032x" % ip_int
+ parts = [hex_str[x : x + 4] for x in range(0, 32, 4)]
+ if isinstance(self, (_BaseNetwork, IPv6Interface)):
+ return "%s/%d" % (":".join(parts), self._prefixlen)
+ return ":".join(parts)
+
+ def _reverse_pointer(self):
+ """Return the reverse DNS pointer name for the IPv6 address.
+
+ This implements the method described in RFC3596 2.5.
+
+ """
+ reverse_chars = self.exploded[::-1].replace(":", "")
+ return ".".join(reverse_chars) + ".ip6.arpa"
+
+ @property
+ def max_prefixlen(self):
+ return self._max_prefixlen
+
+ @property
+ def version(self):
+ return self._version
+
+
+class IPv6Address(_BaseV6, _BaseAddress):
+
+ """Represent and manipulate single IPv6 Addresses."""
+
+ __slots__ = ("_ip", "__weakref__")
+
+ def __init__(self, address):
+ """Instantiate a new IPv6 address object.
+
+ Args:
+ address: A string or integer representing the IP
+
+ Additionally, an integer can be passed, so
+ IPv6Address('2001:db8::') ==
+ IPv6Address(42540766411282592856903984951653826560)
+ or, more generally
+ IPv6Address(int(IPv6Address('2001:db8::'))) ==
+ IPv6Address('2001:db8::')
+
+ Raises:
+ AddressValueError: If address isn't a valid IPv6 address.
+
+ """
+ # Efficient constructor from integer.
+ if isinstance(address, _compat_int_types):
+ self._check_int_address(address)
+ self._ip = address
+ return
+
+ # Constructing from a packed address
+ if isinstance(address, bytes):
+ self._check_packed_address(address, 16)
+ bvs = _compat_bytes_to_byte_vals(address)
+ self._ip = _compat_int_from_byte_vals(bvs, "big")
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP string.
+ addr_str = _compat_str(address)
+ if "/" in addr_str:
+ raise AddressValueError("Unexpected '/' in %r" % address)
+ self._ip = self._ip_int_from_string(addr_str)
+
+ @property
+ def packed(self):
+ """The binary representation of this address."""
+ return v6_int_to_packed(self._ip)
+
+ @property
+ def is_multicast(self):
+ """Test if the address is reserved for multicast use.
+
+ Returns:
+ A boolean, True if the address is a multicast address.
+ See RFC 2373 2.7 for details.
+
+ """
+ return self in self._constants._multicast_network
+
+ @property
+ def is_reserved(self):
+ """Test if the address is otherwise IETF reserved.
+
+ Returns:
+ A boolean, True if the address is within one of the
+ reserved IPv6 Network ranges.
+
+ """
+ return any(self in x for x in self._constants._reserved_networks)
+
+ @property
+ def is_link_local(self):
+ """Test if the address is reserved for link-local.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 4291.
+
+ """
+ return self in self._constants._linklocal_network
+
+ @property
+ def is_site_local(self):
+ """Test if the address is reserved for site-local.
+
+ Note that the site-local address space has been deprecated by RFC 3879.
+ Use is_private to test if this address is in the space of unique local
+ addresses as defined by RFC 4193.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 3513 2.5.6.
+
+ """
+ return self in self._constants._sitelocal_network
+
+ @property
+ def is_private(self):
+ """Test if this address is allocated for private networks.
+
+ Returns:
+ A boolean, True if the address is reserved per
+ iana-ipv6-special-registry.
+
+ """
+ return any(self in net for net in self._constants._private_networks)
+
+ @property
+ def is_global(self):
+ """Test if this address is allocated for public networks.
+
+ Returns:
+ A boolean, true if the address is not reserved per
+ iana-ipv6-special-registry.
+
+ """
+ return not self.is_private
+
+ @property
+ def is_unspecified(self):
+ """Test if the address is unspecified.
+
+ Returns:
+ A boolean, True if this is the unspecified address as defined in
+ RFC 2373 2.5.2.
+
+ """
+ return self._ip == 0
+
+ @property
+ def is_loopback(self):
+ """Test if the address is a loopback address.
+
+ Returns:
+ A boolean, True if the address is a loopback address as defined in
+ RFC 2373 2.5.3.
+
+ """
+ return self._ip == 1
+
+ @property
+ def ipv4_mapped(self):
+ """Return the IPv4 mapped address.
+
+ Returns:
+ If the IPv6 address is a v4 mapped address, return the
+ IPv4 mapped address. Return None otherwise.
+
+ """
+ if (self._ip >> 32) != 0xFFFF:
+ return None
+ return IPv4Address(self._ip & 0xFFFFFFFF)
+
+ @property
+ def teredo(self):
+ """Tuple of embedded teredo IPs.
+
+ Returns:
+ Tuple of the (server, client) IPs or None if the address
+ doesn't appear to be a teredo address (doesn't start with
+ 2001::/32)
+
+ """
+ if (self._ip >> 96) != 0x20010000:
+ return None
+ return (
+ IPv4Address((self._ip >> 64) & 0xFFFFFFFF),
+ IPv4Address(~self._ip & 0xFFFFFFFF),
+ )
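+
+    # Illustrative sketch, using the well-known Teredo example address
+    # 2001:0:4136:e378:8000:63bf:3fff:fdd2: bits 64..95 hold the server
+    # (65.54.227.120) and the low 32 bits hold the client, stored inverted
+    # (~0x3ffffdd2 & 0xffffffff == 0xc000022d, i.e. 192.0.2.45).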
+
+ @property
+ def sixtofour(self):
+ """Return the IPv4 6to4 embedded address.
+
+ Returns:
+ The IPv4 6to4-embedded address if present or None if the
+ address doesn't appear to contain a 6to4 embedded address.
+
+ """
+ if (self._ip >> 112) != 0x2002:
+ return None
+ return IPv4Address((self._ip >> 80) & 0xFFFFFFFF)
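+
+    # Illustrative sketch: for 2002:c000:204:: (a 6to4 prefix), bits 80..111
+    # carry the embedded IPv4 address, so this returns
+    # IPv4Address('192.0.2.4') (0xc0000204).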
+
+
+class IPv6Interface(IPv6Address):
+ def __init__(self, address):
+ if isinstance(address, (bytes, _compat_int_types)):
+ IPv6Address.__init__(self, address)
+ self.network = IPv6Network(self._ip)
+ self._prefixlen = self._max_prefixlen
+ return
+ if isinstance(address, tuple):
+ IPv6Address.__init__(self, address[0])
+ if len(address) > 1:
+ self._prefixlen = int(address[1])
+ else:
+ self._prefixlen = self._max_prefixlen
+ self.network = IPv6Network(address, strict=False)
+ self.netmask = self.network.netmask
+ self.hostmask = self.network.hostmask
+ return
+
+ addr = _split_optional_netmask(address)
+ IPv6Address.__init__(self, addr[0])
+ self.network = IPv6Network(address, strict=False)
+ self.netmask = self.network.netmask
+ self._prefixlen = self.network._prefixlen
+ self.hostmask = self.network.hostmask
+
+ def __str__(self):
+ return "%s/%d" % (
+ self._string_from_ip_int(self._ip),
+ self.network.prefixlen,
+ )
+
+ def __eq__(self, other):
+ address_equal = IPv6Address.__eq__(self, other)
+ if not address_equal or address_equal is NotImplemented:
+ return address_equal
+ try:
+ return self.network == other.network
+ except AttributeError:
+ # An interface with an associated network is NOT the
+ # same as an unassociated address. That's why the hash
+ # takes the extra info into account.
+ return False
+
+ def __lt__(self, other):
+ address_less = IPv6Address.__lt__(self, other)
+ if address_less is NotImplemented:
+ return NotImplemented
+ try:
+ return (
+ self.network < other.network
+ or self.network == other.network
+ and address_less
+ )
+ except AttributeError:
+ # We *do* allow addresses and interfaces to be sorted. The
+ # unassociated address is considered less than all interfaces.
+ return False
+
+ def __hash__(self):
+ return self._ip ^ self._prefixlen ^ int(self.network.network_address)
+
+ __reduce__ = _IPAddressBase.__reduce__
+
+ @property
+ def ip(self):
+ return IPv6Address(self._ip)
+
+ @property
+ def with_prefixlen(self):
+ return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen)
+
+ @property
+ def with_netmask(self):
+ return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask)
+
+ @property
+ def with_hostmask(self):
+ return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask)
+
+ @property
+ def is_unspecified(self):
+ return self._ip == 0 and self.network.is_unspecified
+
+ @property
+ def is_loopback(self):
+ return self._ip == 1 and self.network.is_loopback
+
+
+class IPv6Network(_BaseV6, _BaseNetwork):
+
+ """This class represents and manipulates 128-bit IPv6 networks.
+
+ Attributes: [examples for IPv6('2001:db8::1000/124')]
+ .network_address: IPv6Address('2001:db8::1000')
+ .hostmask: IPv6Address('::f')
+ .broadcast_address: IPv6Address('2001:db8::100f')
+ .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0')
+ .prefixlen: 124
+
+ """
+
+ # Class to use when creating address objects
+ _address_class = IPv6Address
+
+ def __init__(self, address, strict=True):
+ """Instantiate a new IPv6 Network object.
+
+ Args:
+ address: A string or integer representing the IPv6 network or the
+ IP and prefix/netmask.
+ '2001:db8::/128'
+ '2001:db8:0000:0000:0000:0000:0000:0000/128'
+ '2001:db8::'
+ are all functionally the same in IPv6. That is to say,
+ failing to provide a subnetmask will create an object with
+ a mask of /128.
+
+ Additionally, an integer can be passed, so
+ IPv6Network('2001:db8::') ==
+ IPv6Network(42540766411282592856903984951653826560)
+ or, more generally
+ IPv6Network(int(IPv6Network('2001:db8::'))) ==
+ IPv6Network('2001:db8::')
+
+              strict: A boolean. If true, ensure that we have been passed
+                a true network address, e.g. 2001:db8::1000/124, and not an
+                IP address on a network, e.g. 2001:db8::1/124.
+
+ Raises:
+ AddressValueError: If address isn't a valid IPv6 address.
+ NetmaskValueError: If the netmask isn't valid for
+ an IPv6 address.
+ ValueError: If strict was True and a network address was not
+ supplied.
+
+ """
+ _BaseNetwork.__init__(self, address)
+
+ # Efficient constructor from integer or packed address
+ if isinstance(address, (bytes, _compat_int_types)):
+ self.network_address = IPv6Address(address)
+ self.netmask, self._prefixlen = self._make_netmask(
+ self._max_prefixlen
+ )
+ return
+
+ if isinstance(address, tuple):
+ if len(address) > 1:
+ arg = address[1]
+ else:
+ arg = self._max_prefixlen
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+ self.network_address = IPv6Address(address[0])
+ packed = int(self.network_address)
+ if packed & int(self.netmask) != packed:
+ if strict:
+ raise ValueError("%s has host bits set" % self)
+ else:
+ self.network_address = IPv6Address(
+ packed & int(self.netmask)
+ )
+ return
+
+ # Assume input argument to be string or any object representation
+ # which converts into a formatted IP prefix string.
+ addr = _split_optional_netmask(address)
+
+ self.network_address = IPv6Address(self._ip_int_from_string(addr[0]))
+
+ if len(addr) == 2:
+ arg = addr[1]
+ else:
+ arg = self._max_prefixlen
+ self.netmask, self._prefixlen = self._make_netmask(arg)
+
+ if strict:
+ if (
+ IPv6Address(int(self.network_address) & int(self.netmask))
+ != self.network_address
+ ):
+ raise ValueError("%s has host bits set" % self)
+ self.network_address = IPv6Address(
+ int(self.network_address) & int(self.netmask)
+ )
+
+ if self._prefixlen == (self._max_prefixlen - 1):
+ self.hosts = self.__iter__
+
+ def hosts(self):
+ """Generate Iterator over usable hosts in a network.
+
+ This is like __iter__ except it doesn't return the
+ Subnet-Router anycast address.
+
+ """
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in _compat_range(network + 1, broadcast + 1):
+ yield self._address_class(x)
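+
+    # Illustrative sketch: IPv6Network('2001:db8::/126').hosts() yields
+    # 2001:db8::1, 2001:db8::2 and 2001:db8::3, skipping only the
+    # Subnet-Router anycast address 2001:db8:: (IPv6 has no broadcast,
+    # so the highest address is included).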
+
+ @property
+ def is_site_local(self):
+ """Test if the address is reserved for site-local.
+
+ Note that the site-local address space has been deprecated by RFC 3879.
+ Use is_private to test if this address is in the space of unique local
+ addresses as defined by RFC 4193.
+
+ Returns:
+ A boolean, True if the address is reserved per RFC 3513 2.5.6.
+
+ """
+ return (
+ self.network_address.is_site_local
+ and self.broadcast_address.is_site_local
+ )
+
+
+class _IPv6Constants(object):
+
+ _linklocal_network = IPv6Network("fe80::/10")
+
+ _multicast_network = IPv6Network("ff00::/8")
+
+ _private_networks = [
+ IPv6Network("::1/128"),
+ IPv6Network("::/128"),
+ IPv6Network("::ffff:0:0/96"),
+ IPv6Network("100::/64"),
+ IPv6Network("2001::/23"),
+ IPv6Network("2001:2::/48"),
+ IPv6Network("2001:db8::/32"),
+ IPv6Network("2001:10::/28"),
+ IPv6Network("fc00::/7"),
+ IPv6Network("fe80::/10"),
+ ]
+
+ _reserved_networks = [
+ IPv6Network("::/8"),
+ IPv6Network("100::/8"),
+ IPv6Network("200::/7"),
+ IPv6Network("400::/6"),
+ IPv6Network("800::/5"),
+ IPv6Network("1000::/4"),
+ IPv6Network("4000::/3"),
+ IPv6Network("6000::/3"),
+ IPv6Network("8000::/3"),
+ IPv6Network("A000::/3"),
+ IPv6Network("C000::/3"),
+ IPv6Network("E000::/4"),
+ IPv6Network("F000::/5"),
+ IPv6Network("F800::/6"),
+ IPv6Network("FE00::/9"),
+ ]
+
+ _sitelocal_network = IPv6Network("fec0::/10")
+
+
+IPv6Address._constants = _IPv6Constants
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py
new file mode 100644
index 00000000..68608d1b
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py
@@ -0,0 +1,27 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The base class for all resource modules
+"""
+
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import (
+ get_resource_connection,
+)
+
+
+class ConfigBase(object):
+ """ The base class for all resource modules
+ """
+
+ ACTION_STATES = ["merged", "replaced", "overridden", "deleted"]
+
+ def __init__(self, module):
+ self._module = module
+ self.state = module.params["state"]
+ self._connection = None
+
+ if self.state not in ["rendered", "parsed"]:
+ self._connection = get_resource_connection(module)
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py
new file mode 100644
index 00000000..bc458eb5
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py
@@ -0,0 +1,473 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# (c) 2016 Red Hat Inc.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+import re
+import hashlib
+
+from ansible.module_utils.six.moves import zip
+from ansible.module_utils._text import to_bytes, to_native
+
+DEFAULT_COMMENT_TOKENS = ["#", "!", "/*", "*/", "echo"]
+
+DEFAULT_IGNORE_LINES_RE = set(
+ [
+ re.compile(r"Using \d+ out of \d+ bytes"),
+ re.compile(r"Building configuration"),
+ re.compile(r"Current configuration : \d+ bytes"),
+ ]
+)
+
+
+try:
+ Pattern = re._pattern_type
+except AttributeError:
+ Pattern = re.Pattern
+
+
+class ConfigLine(object):
+ def __init__(self, raw):
+ self.text = str(raw).strip()
+ self.raw = raw
+ self._children = list()
+ self._parents = list()
+
+ def __str__(self):
+ return self.raw
+
+ def __eq__(self, other):
+ return self.line == other.line
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __getitem__(self, key):
+ for item in self._children:
+ if item.text == key:
+ return item
+ raise KeyError(key)
+
+ @property
+ def line(self):
+ line = self.parents
+ line.append(self.text)
+ return " ".join(line)
+
+ @property
+ def children(self):
+ return _obj_to_text(self._children)
+
+ @property
+ def child_objs(self):
+ return self._children
+
+ @property
+ def parents(self):
+ return _obj_to_text(self._parents)
+
+ @property
+ def path(self):
+ config = _obj_to_raw(self._parents)
+ config.append(self.raw)
+ return "\n".join(config)
+
+ @property
+ def has_children(self):
+ return len(self._children) > 0
+
+ @property
+ def has_parents(self):
+ return len(self._parents) > 0
+
+ def add_child(self, obj):
+ if not isinstance(obj, ConfigLine):
+ raise AssertionError("child must be of type `ConfigLine`")
+ self._children.append(obj)
+
+
+def ignore_line(text, tokens=None):
+ for item in tokens or DEFAULT_COMMENT_TOKENS:
+ if text.startswith(item):
+ return True
+ for regex in DEFAULT_IGNORE_LINES_RE:
+ if regex.match(text):
+ return True
+
+
+def _obj_to_text(x):
+ return [o.text for o in x]
+
+
+def _obj_to_raw(x):
+ return [o.raw for o in x]
+
+
+def _obj_to_block(objects, visited=None):
+ items = list()
+ for o in objects:
+ if o not in items:
+ items.append(o)
+ for child in o._children:
+ if child not in items:
+ items.append(child)
+ return _obj_to_raw(items)
+
+
+def dumps(objects, output="block", comments=False):
+ if output == "block":
+ items = _obj_to_block(objects)
+ elif output == "commands":
+ items = _obj_to_text(objects)
+ elif output == "raw":
+ items = _obj_to_raw(objects)
+ else:
+ raise TypeError("unknown value supplied for keyword output")
+
+ if output == "block":
+ if comments:
+ for index, item in enumerate(items):
+ nextitem = index + 1
+ if (
+ nextitem < len(items)
+ and not item.startswith(" ")
+ and items[nextitem].startswith(" ")
+ ):
+ item = "!\n%s" % item
+ items[index] = item
+ items.append("!")
+ items.append("end")
+
+ return "\n".join(items)
+
+
+class NetworkConfig(object):
+ def __init__(self, indent=1, contents=None, ignore_lines=None):
+ self._indent = indent
+ self._items = list()
+ self._config_text = None
+
+ if ignore_lines:
+ for item in ignore_lines:
+ if not isinstance(item, Pattern):
+ item = re.compile(item)
+ DEFAULT_IGNORE_LINES_RE.add(item)
+
+ if contents:
+ self.load(contents)
+
+ @property
+ def items(self):
+ return self._items
+
+ @property
+ def config_text(self):
+ return self._config_text
+
+ @property
+ def sha1(self):
+ sha1 = hashlib.sha1()
+ sha1.update(to_bytes(str(self), errors="surrogate_or_strict"))
+ return sha1.digest()
+
+ def __getitem__(self, key):
+ for line in self:
+ if line.text == key:
+ return line
+ raise KeyError(key)
+
+ def __iter__(self):
+ return iter(self._items)
+
+ def __str__(self):
+ return "\n".join([c.raw for c in self.items])
+
+ def __len__(self):
+ return len(self._items)
+
+ def load(self, s):
+ self._config_text = s
+ self._items = self.parse(s)
+
+ def loadfp(self, fp):
+ with open(fp) as f:
+ return self.load(f.read())
+
+ def parse(self, lines, comment_tokens=None):
+ toplevel = re.compile(r"\S")
+ childline = re.compile(r"^\s*(.+)$")
+ entry_reg = re.compile(r"([{};])")
+
+ ancestors = list()
+ config = list()
+
+ indents = [0]
+
+ for linenum, line in enumerate(
+ to_native(lines, errors="surrogate_or_strict").split("\n")
+ ):
+ text = entry_reg.sub("", line).strip()
+
+ cfg = ConfigLine(line)
+
+ if not text or ignore_line(text, comment_tokens):
+ continue
+
+ # handle top level commands
+ if toplevel.match(line):
+ ancestors = [cfg]
+ indents = [0]
+
+ # handle sub level commands
+ else:
+ match = childline.match(line)
+ line_indent = match.start(1)
+
+ if line_indent < indents[-1]:
+ while indents[-1] > line_indent:
+ indents.pop()
+
+ if line_indent > indents[-1]:
+ indents.append(line_indent)
+
+ curlevel = len(indents) - 1
+ parent_level = curlevel - 1
+
+ cfg._parents = ancestors[:curlevel]
+
+ if curlevel > len(ancestors):
+ config.append(cfg)
+ continue
+
+ for i in range(curlevel, len(ancestors)):
+ ancestors.pop()
+
+ ancestors.append(cfg)
+ ancestors[parent_level].add_child(cfg)
+
+ config.append(cfg)
+
+ return config
+
+ def get_object(self, path):
+ for item in self.items:
+ if item.text == path[-1]:
+ if item.parents == path[:-1]:
+ return item
+
+ def get_block(self, path):
+ if not isinstance(path, list):
+ raise AssertionError("path argument must be a list object")
+ obj = self.get_object(path)
+ if not obj:
+ raise ValueError("path does not exist in config")
+ return self._expand_block(obj)
+
+ def get_block_config(self, path):
+ block = self.get_block(path)
+ return dumps(block, "block")
+
+ def _expand_block(self, configobj, S=None):
+ if S is None:
+ S = list()
+ S.append(configobj)
+ for child in configobj._children:
+ if child in S:
+ continue
+ self._expand_block(child, S)
+ return S
+
+ def _diff_line(self, other):
+ updates = list()
+ for item in self.items:
+ if item not in other:
+ updates.append(item)
+ return updates
+
+ def _diff_strict(self, other):
+ updates = list()
+        # A block extracted from other includes only its last parent,
+        # not the full ancestor chain, so when there are multiple
+        # parents we need to add the missing ones.
+ if other and isinstance(other, list) and len(other) > 0:
+ start_other = other[0]
+ if start_other.parents:
+ for parent in start_other.parents:
+ other.insert(0, ConfigLine(parent))
+ for index, line in enumerate(self.items):
+ try:
+ if str(line).strip() != str(other[index]).strip():
+ updates.append(line)
+ except (AttributeError, IndexError):
+ updates.append(line)
+ return updates
+
+ def _diff_exact(self, other):
+ updates = list()
+ if len(other) != len(self.items):
+ updates.extend(self.items)
+ else:
+ for ours, theirs in zip(self.items, other):
+ if ours != theirs:
+ updates.extend(self.items)
+ break
+ return updates
+
+ def difference(self, other, match="line", path=None, replace=None):
+ """Perform a config diff against the another network config
+
+ :param other: instance of NetworkConfig to diff against
+ :param match: type of diff to perform. valid values are 'line',
+ 'strict', 'exact'
+ :param path: context in the network config to filter the diff
+ :param replace: the method used to generate the replacement lines.
+ valid values are 'block', 'line'
+
+ :returns: a string of lines that are different
+ """
+ if path and match != "line":
+ try:
+ other = other.get_block(path)
+ except ValueError:
+ other = list()
+ else:
+ other = other.items
+
+ # generate a list of ConfigLines that aren't in other
+ meth = getattr(self, "_diff_%s" % match)
+ updates = meth(other)
+
+ if replace == "block":
+ parents = list()
+ for item in updates:
+ if not item.has_parents:
+ parents.append(item)
+ else:
+ for p in item._parents:
+ if p not in parents:
+ parents.append(p)
+
+ updates = list()
+ for item in parents:
+ updates.extend(self._expand_block(item))
+
+ visited = set()
+ expanded = list()
+
+ for item in updates:
+ for p in item._parents:
+ if p.line not in visited:
+ visited.add(p.line)
+ expanded.append(p)
+ expanded.append(item)
+ visited.add(item.line)
+
+ return expanded
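+
+    # A minimal usage sketch (assuming space-indented, IOS-style config text):
+    #
+    #   running = NetworkConfig(indent=1, contents="interface Eth0\n description old")
+    #   candidate = NetworkConfig(indent=1, contents="interface Eth0\n description new")
+    #   updates = candidate.difference(running)  # ConfigLine objects absent from running
+    #   dumps(updates, "commands")               # "interface Eth0\ndescription new"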
+
+ def add(self, lines, parents=None):
+ ancestors = list()
+ offset = 0
+ obj = None
+
+ # global config command
+ if not parents:
+ for line in lines:
+ # handle ignore lines
+ if ignore_line(line):
+ continue
+
+ item = ConfigLine(line)
+ item.raw = line
+ if item not in self.items:
+ self.items.append(item)
+
+ else:
+ for index, p in enumerate(parents):
+ try:
+ i = index + 1
+ obj = self.get_block(parents[:i])[0]
+ ancestors.append(obj)
+
+ except ValueError:
+ # add parent to config
+ offset = index * self._indent
+ obj = ConfigLine(p)
+ obj.raw = p.rjust(len(p) + offset)
+ if ancestors:
+ obj._parents = list(ancestors)
+ ancestors[-1]._children.append(obj)
+ self.items.append(obj)
+ ancestors.append(obj)
+
+ # add child objects
+ for line in lines:
+ # handle ignore lines
+ if ignore_line(line):
+ continue
+
+ # check if child already exists
+ for child in ancestors[-1]._children:
+ if child.text == line:
+ break
+ else:
+ offset = len(parents) * self._indent
+ item = ConfigLine(line)
+ item.raw = line.rjust(len(line) + offset)
+ item._parents = ancestors
+ ancestors[-1]._children.append(item)
+ self.items.append(item)
+
+
+class CustomNetworkConfig(NetworkConfig):
+ def items_text(self):
+ return [item.text for item in self.items]
+
+ def expand_section(self, configobj, S=None):
+ if S is None:
+ S = list()
+ S.append(configobj)
+ for child in configobj.child_objs:
+ if child in S:
+ continue
+ self.expand_section(child, S)
+ return S
+
+ def to_block(self, section):
+ return "\n".join([item.raw for item in section])
+
+ def get_section(self, path):
+ try:
+ section = self.get_section_objects(path)
+ return self.to_block(section)
+ except ValueError:
+ return list()
+
+ def get_section_objects(self, path):
+ if not isinstance(path, list):
+ path = [path]
+ obj = self.get_object(path)
+ if not obj:
+ raise ValueError("path does not exist in config")
+ return self.expand_section(obj)
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py
new file mode 100644
index 00000000..477d3184
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py
@@ -0,0 +1,162 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The facts base class.
+This contains methods common to all facts subsets.
+"""
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import (
+ get_resource_connection,
+)
+from ansible.module_utils.six import iteritems
+
+
+class FactsBase(object):
+ """
+ The facts base class
+ """
+
+ def __init__(self, module):
+ self._module = module
+ self._warnings = []
+ self._gather_subset = module.params.get("gather_subset")
+ self._gather_network_resources = module.params.get(
+ "gather_network_resources"
+ )
+ self._connection = None
+ if module.params.get("state") not in ["rendered", "parsed"]:
+ self._connection = get_resource_connection(module)
+
+ self.ansible_facts = {"ansible_network_resources": {}}
+ self.ansible_facts["ansible_net_gather_network_resources"] = list()
+ self.ansible_facts["ansible_net_gather_subset"] = list()
+
+ if not self._gather_subset:
+ self._gather_subset = ["!config"]
+ if not self._gather_network_resources:
+ self._gather_network_resources = ["!all"]
+
+ def gen_runable(self, subsets, valid_subsets, resource_facts=False):
+ """ Generate the runable subset
+
+ :param module: The module instance
+ :param subsets: The provided subsets
+ :param valid_subsets: The valid subsets
+ :param resource_facts: A boolean flag
+ :rtype: list
+ :returns: The runable subsets
+ """
+ runable_subsets = set()
+ exclude_subsets = set()
+ minimal_gather_subset = set()
+ if not resource_facts:
+ minimal_gather_subset = frozenset(["default"])
+
+ for subset in subsets:
+ if subset == "all":
+ runable_subsets.update(valid_subsets)
+ continue
+ if subset == "min" and minimal_gather_subset:
+ runable_subsets.update(minimal_gather_subset)
+ continue
+ if subset.startswith("!"):
+ subset = subset[1:]
+ if subset == "min":
+ exclude_subsets.update(minimal_gather_subset)
+ continue
+ if subset == "all":
+ exclude_subsets.update(
+ valid_subsets - minimal_gather_subset
+ )
+ continue
+ exclude = True
+ else:
+ exclude = False
+
+ if subset not in valid_subsets:
+ self._module.fail_json(
+ msg="Subset must be one of [%s], got %s"
+ % (
+ ", ".join(sorted([item for item in valid_subsets])),
+ subset,
+ )
+ )
+
+ if exclude:
+ exclude_subsets.add(subset)
+ else:
+ runable_subsets.add(subset)
+
+ if not runable_subsets:
+ runable_subsets.update(valid_subsets)
+ runable_subsets.difference_update(exclude_subsets)
+ return runable_subsets
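+
+    # Illustrative sketch: with valid_subsets frozenset({"default", "config",
+    # "interfaces"}) and the default gather_subset ["!config"], "config" lands
+    # in exclude_subsets, runable_subsets falls back to all valid subsets, and
+    # the result is {"default", "interfaces"}.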
+
+ def get_network_resources_facts(
+ self, facts_resource_obj_map, resource_facts_type=None, data=None
+ ):
+ """
+        :param facts_resource_obj_map: map of resource subset names to
+            their facts classes
+        :param resource_facts_type: the resource subsets to gather;
+            defaults to the module's gather_network_resources
+        :param data: previously collected configuration
+        :return: None; results are written to self.ansible_facts
+ """
+ if not resource_facts_type:
+ resource_facts_type = self._gather_network_resources
+
+ restorun_subsets = self.gen_runable(
+ resource_facts_type,
+ frozenset(facts_resource_obj_map.keys()),
+ resource_facts=True,
+ )
+ if restorun_subsets:
+ self.ansible_facts["ansible_net_gather_network_resources"] = list(
+ restorun_subsets
+ )
+ instances = list()
+ for key in restorun_subsets:
+ fact_cls_obj = facts_resource_obj_map.get(key)
+ if fact_cls_obj:
+ instances.append(fact_cls_obj(self._module))
+ else:
+ self._warnings.extend(
+ [
+ "network resource fact gathering for '%s' is not supported"
+ % key
+ ]
+ )
+
+ for inst in instances:
+ inst.populate_facts(self._connection, self.ansible_facts, data)
+
+ def get_network_legacy_facts(
+ self, fact_legacy_obj_map, legacy_facts_type=None
+ ):
+ if not legacy_facts_type:
+ legacy_facts_type = self._gather_subset
+
+ runable_subsets = self.gen_runable(
+ legacy_facts_type, frozenset(fact_legacy_obj_map.keys())
+ )
+ if runable_subsets:
+ facts = dict()
+            # the default subset should always be returned with legacy facts subsets
+ if "default" not in runable_subsets:
+ runable_subsets.add("default")
+ self.ansible_facts["ansible_net_gather_subset"] = list(
+ runable_subsets
+ )
+
+ instances = list()
+ for key in runable_subsets:
+ instances.append(fact_legacy_obj_map[key](self._module))
+
+ for inst in instances:
+ inst.populate()
+ facts.update(inst.facts)
+ self._warnings.extend(inst.warnings)
+
+ for key, value in iteritems(facts):
+ key = "ansible_net_%s" % key
+ self.ansible_facts[key] = value
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py
new file mode 100644
index 00000000..53a91e8c
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py
@@ -0,0 +1,179 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# (c) 2017 Red Hat Inc.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+import sys
+
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.module_utils.connection import Connection, ConnectionError
+
+try:
+ from ncclient.xml_ import NCElement, new_ele, sub_ele
+
+ HAS_NCCLIENT = True
+except (ImportError, AttributeError):
+ HAS_NCCLIENT = False
+
+try:
+ from lxml.etree import Element, fromstring, XMLSyntaxError
+except ImportError:
+ from xml.etree.ElementTree import Element, fromstring
+
+ if sys.version_info < (2, 7):
+ from xml.parsers.expat import ExpatError as XMLSyntaxError
+ else:
+ from xml.etree.ElementTree import ParseError as XMLSyntaxError
+
+NS_MAP = {"nc": "urn:ietf:params:xml:ns:netconf:base:1.0"}
+
+
+def exec_rpc(module, *args, **kwargs):
+ connection = NetconfConnection(module._socket_path)
+ return connection.execute_rpc(*args, **kwargs)
+
+
+class NetconfConnection(Connection):
+ def __init__(self, socket_path):
+ super(NetconfConnection, self).__init__(socket_path)
+
+ def __rpc__(self, name, *args, **kwargs):
+ """Executes the json-rpc and returns the output received
+ from remote device.
+ :name: rpc method to be executed over connection plugin that implements jsonrpc 2.0
+ :args: Ordered list of params passed as arguments to rpc method
+ :kwargs: Dict of valid key, value pairs passed as arguments to rpc method
+
+ For usage refer the respective connection plugin docs.
+ """
+ self.check_rc = kwargs.pop("check_rc", True)
+ self.ignore_warning = kwargs.pop("ignore_warning", True)
+
+ response = self._exec_jsonrpc(name, *args, **kwargs)
+ if "error" in response:
+ rpc_error = response["error"].get("data")
+ return self.parse_rpc_error(
+ to_bytes(rpc_error, errors="surrogate_then_replace")
+ )
+
+ return fromstring(
+ to_bytes(response["result"], errors="surrogate_then_replace")
+ )
+
+ def parse_rpc_error(self, rpc_error):
+ if self.check_rc:
+ try:
+ error_root = fromstring(rpc_error)
+ root = Element("root")
+ root.append(error_root)
+
+ error_list = root.findall(".//nc:rpc-error", NS_MAP)
+ if not error_list:
+ raise ConnectionError(
+ to_text(rpc_error, errors="surrogate_then_replace")
+ )
+
+ warnings = []
+ for error in error_list:
+ message_ele = error.find("./nc:error-message", NS_MAP)
+
+ if message_ele is None:
+ message_ele = error.find("./nc:error-info", NS_MAP)
+
+ message = (
+ message_ele.text if message_ele is not None else None
+ )
+
+ severity = error.find("./nc:error-severity", NS_MAP).text
+
+ if (
+ severity == "warning"
+ and self.ignore_warning
+ and message is not None
+ ):
+ warnings.append(message)
+ else:
+ raise ConnectionError(
+ to_text(rpc_error, errors="surrogate_then_replace")
+ )
+ return warnings
+ except XMLSyntaxError:
+ raise ConnectionError(rpc_error)
+
+
+def transform_reply():
+ return b"""<xsl:stylesheet version="1.0" xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
+ <xsl:output method="xml" indent="no"/>
+
+ <xsl:template match="/|comment()|processing-instruction()">
+ <xsl:copy>
+ <xsl:apply-templates/>
+ </xsl:copy>
+ </xsl:template>
+
+ <xsl:template match="*">
+ <xsl:element name="{local-name()}">
+ <xsl:apply-templates select="@*|node()"/>
+ </xsl:element>
+ </xsl:template>
+
+ <xsl:template match="@*">
+ <xsl:attribute name="{local-name()}">
+ <xsl:value-of select="."/>
+ </xsl:attribute>
+ </xsl:template>
+ </xsl:stylesheet>
+ """
+
+
+# Note: Workaround for ncclient 0.5.3
+def remove_namespaces(data):
+ if not HAS_NCCLIENT:
+ raise ImportError(
+ "ncclient is required but does not appear to be installed. "
+ "It can be installed using `pip install ncclient`"
+ )
+ return NCElement(data, transform_reply()).data_xml
+
+
+def build_root_xml_node(tag):
+ return new_ele(tag)
+
+
+def build_child_xml_node(parent, tag, text=None, attrib=None):
+ element = sub_ele(parent, tag)
+ if text:
+ element.text = to_text(text)
+ if attrib:
+ element.attrib.update(attrib)
+ return element
+
+
+def build_subtree(parent, path):
+ element = parent
+ for field in path.split("/"):
+ sub_element = build_child_xml_node(element, field)
+ element = sub_element
+ return element
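+
+# A minimal usage sketch (requires ncclient for new_ele/sub_ele):
+#
+#   root = build_root_xml_node("config")
+#   build_subtree(root, "interfaces/interface")
+#
+# builds <config><interfaces><interface/></interfaces></config> and returns
+# the innermost <interface> element.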
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py
new file mode 100644
index 00000000..555fc713
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py
@@ -0,0 +1,275 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com>
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import traceback
+import json
+
+from ansible.module_utils._text import to_text, to_native
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.basic import env_fallback
+from ansible.module_utils.connection import Connection, ConnectionError
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import (
+ NetconfConnection,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import (
+ Cli,
+)
+from ansible.module_utils.six import iteritems
+
+
+NET_TRANSPORT_ARGS = dict(
+ host=dict(required=True),
+ port=dict(type="int"),
+ username=dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])),
+ password=dict(
+ no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"])
+ ),
+ ssh_keyfile=dict(
+ fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path"
+ ),
+ authorize=dict(
+ default=False,
+ fallback=(env_fallback, ["ANSIBLE_NET_AUTHORIZE"]),
+ type="bool",
+ ),
+ auth_pass=dict(
+ no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_AUTH_PASS"])
+ ),
+ provider=dict(type="dict", no_log=True),
+ transport=dict(choices=list()),
+ timeout=dict(default=10, type="int"),
+)
+
+NET_CONNECTION_ARGS = dict()
+
+NET_CONNECTIONS = dict()
+
+
+def _transitional_argument_spec():
+ argument_spec = {}
+ for key, value in iteritems(NET_TRANSPORT_ARGS):
+ value["required"] = False
+ argument_spec[key] = value
+ return argument_spec
+
+
+def to_list(val):
+ if isinstance(val, (list, tuple)):
+ return list(val)
+ elif val is not None:
+ return [val]
+ else:
+ return list()
+
+
+class ModuleStub(object):
+ def __init__(self, argument_spec, fail_json):
+ self.params = dict()
+ for key, value in argument_spec.items():
+ self.params[key] = value.get("default")
+ self.fail_json = fail_json
+
+
+class NetworkError(Exception):
+ def __init__(self, msg, **kwargs):
+ super(NetworkError, self).__init__(msg)
+ self.kwargs = kwargs
+
+
+class Config(object):
+ def __init__(self, connection):
+ self.connection = connection
+
+ def __call__(self, commands, **kwargs):
+ lines = to_list(commands)
+ return self.connection.configure(lines, **kwargs)
+
+ def load_config(self, commands, **kwargs):
+ commands = to_list(commands)
+ return self.connection.load_config(commands, **kwargs)
+
+ def get_config(self, **kwargs):
+ return self.connection.get_config(**kwargs)
+
+ def save_config(self):
+ return self.connection.save_config()
+
+
+class NetworkModule(AnsibleModule):
+ def __init__(self, *args, **kwargs):
+ connect_on_load = kwargs.pop("connect_on_load", True)
+
+ argument_spec = NET_TRANSPORT_ARGS.copy()
+ argument_spec["transport"]["choices"] = NET_CONNECTIONS.keys()
+ argument_spec.update(NET_CONNECTION_ARGS.copy())
+
+ if kwargs.get("argument_spec"):
+ argument_spec.update(kwargs["argument_spec"])
+ kwargs["argument_spec"] = argument_spec
+
+ super(NetworkModule, self).__init__(*args, **kwargs)
+
+ self.connection = None
+ self._cli = None
+ self._config = None
+
+ try:
+ transport = self.params["transport"] or "__default__"
+ cls = NET_CONNECTIONS[transport]
+ self.connection = cls()
+ except KeyError:
+ self.fail_json(
+ msg="Unknown transport or no default transport specified"
+ )
+ except (TypeError, NetworkError) as exc:
+ self.fail_json(
+ msg=to_native(exc), exception=traceback.format_exc()
+ )
+
+ if connect_on_load:
+ self.connect()
+
+ @property
+ def cli(self):
+ if not self.connected:
+ self.connect()
+ if self._cli:
+ return self._cli
+ self._cli = Cli(self.connection)
+ return self._cli
+
+ @property
+ def config(self):
+ if not self.connected:
+ self.connect()
+ if self._config:
+ return self._config
+ self._config = Config(self.connection)
+ return self._config
+
+ @property
+ def connected(self):
+ return self.connection._connected
+
+ def _load_params(self):
+ super(NetworkModule, self)._load_params()
+ provider = self.params.get("provider") or dict()
+ for key, value in provider.items():
+ for args in [NET_TRANSPORT_ARGS, NET_CONNECTION_ARGS]:
+ if key in args:
+ if self.params.get(key) is None and value is not None:
+ self.params[key] = value
+
+ def connect(self):
+ try:
+ if not self.connected:
+ self.connection.connect(self.params)
+ if self.params["authorize"]:
+ self.connection.authorize(self.params)
+ self.log(
+ "connected to %s:%s using %s"
+ % (
+ self.params["host"],
+ self.params["port"],
+ self.params["transport"],
+ )
+ )
+ except NetworkError as exc:
+ self.fail_json(
+ msg=to_native(exc), exception=traceback.format_exc()
+ )
+
+ def disconnect(self):
+ try:
+ if self.connected:
+ self.connection.disconnect()
+ self.log("disconnected from %s" % self.params["host"])
+ except NetworkError as exc:
+ self.fail_json(
+ msg=to_native(exc), exception=traceback.format_exc()
+ )
+
+
+def register_transport(transport, default=False):
+ def register(cls):
+ NET_CONNECTIONS[transport] = cls
+ if default:
+ NET_CONNECTIONS["__default__"] = cls
+ return cls
+
+ return register
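+
+# A minimal usage sketch (the Cli class below is hypothetical):
+#
+#   @register_transport("cli", default=True)
+#   class Cli(object):
+#       ...
+#
+# NetworkModule then resolves params["transport"] (or "__default__") to the
+# registered class via NET_CONNECTIONS.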
+
+
+def add_argument(key, value):
+ NET_CONNECTION_ARGS[key] = value
+
+
+def get_resource_connection(module):
+ if hasattr(module, "_connection"):
+ return module._connection
+
+ capabilities = get_capabilities(module)
+ network_api = capabilities.get("network_api")
+ if network_api in ("cliconf", "nxapi", "eapi", "exosapi"):
+ module._connection = Connection(module._socket_path)
+ elif network_api == "netconf":
+ module._connection = NetconfConnection(module._socket_path)
+ elif network_api == "local":
+ # This isn't supported, but we shouldn't fail here.
+ # Set the connection to a fake connection so it fails sensibly.
+ module._connection = LocalResourceConnection(module)
+ else:
+ module.fail_json(
+ msg="Invalid connection type {0!s}".format(network_api)
+ )
+
+ return module._connection
+
+
+def get_capabilities(module):
+ if hasattr(module, "capabilities"):
+ return module._capabilities
+ try:
+ capabilities = Connection(module._socket_path).get_capabilities()
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+ except AssertionError:
+ # No socket_path, connection most likely local.
+ return dict(network_api="local")
+ module._capabilities = json.loads(capabilities)
+
+ return module._capabilities
+
+
+class LocalResourceConnection:
+ def __init__(self, module):
+ self.module = module
+
+ def get(self, *args, **kwargs):
+ self.module.fail_json(
+ msg="Network resource modules not supported over local connection."
+ )
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py
new file mode 100644
index 00000000..2dd1de9e
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py
@@ -0,0 +1,316 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com>
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import re
+import shlex
+import time
+
+from ansible.module_utils.parsing.convert_bool import (
+ BOOLEANS_TRUE,
+ BOOLEANS_FALSE,
+)
+from ansible.module_utils.six import string_types, text_type
+from ansible.module_utils.six.moves import zip
+
+
+def to_list(val):
+ if isinstance(val, (list, tuple)):
+ return list(val)
+ elif val is not None:
+ return [val]
+ else:
+ return list()
+
+
+class FailedConditionsError(Exception):
+ def __init__(self, msg, failed_conditions):
+ super(FailedConditionsError, self).__init__(msg)
+ self.failed_conditions = failed_conditions
+
+
+class FailedConditionalError(Exception):
+ def __init__(self, msg, failed_conditional):
+ super(FailedConditionalError, self).__init__(msg)
+ self.failed_conditional = failed_conditional
+
+
+class AddCommandError(Exception):
+ def __init__(self, msg, command):
+ super(AddCommandError, self).__init__(msg)
+ self.command = command
+
+
+class AddConditionError(Exception):
+ def __init__(self, msg, condition):
+ super(AddConditionError, self).__init__(msg)
+ self.condition = condition
+
+
+class Cli(object):
+ def __init__(self, connection):
+ self.connection = connection
+ self.default_output = connection.default_output or "text"
+ self._commands = list()
+
+ @property
+ def commands(self):
+ return [str(c) for c in self._commands]
+
+ def __call__(self, commands, output=None):
+ objects = list()
+ for cmd in to_list(commands):
+ objects.append(self.to_command(cmd, output))
+ return self.connection.run_commands(objects)
+
+ def to_command(
+ self, command, output=None, prompt=None, response=None, **kwargs
+ ):
+ output = output or self.default_output
+ if isinstance(command, Command):
+ return command
+ if isinstance(prompt, string_types):
+ prompt = re.compile(re.escape(prompt))
+ return Command(
+ command, output, prompt=prompt, response=response, **kwargs
+ )
+
+ def add_commands(self, commands, output=None, **kwargs):
+ for cmd in commands:
+ self._commands.append(self.to_command(cmd, output, **kwargs))
+
+ def run_commands(self):
+ responses = self.connection.run_commands(self._commands)
+ for resp, cmd in zip(responses, self._commands):
+ cmd.response = resp
+
+ # wipe out the commands list to avoid issues if additional
+ # commands are executed later
+ self._commands = list()
+
+ return responses
+
+
+class Command(object):
+ def __init__(
+ self, command, output=None, prompt=None, response=None, **kwargs
+ ):
+
+ self.command = command
+ self.output = output
+ self.command_string = command
+
+ self.prompt = prompt
+ self.response = response
+
+ self.args = kwargs
+
+ def __str__(self):
+ return self.command_string
+
+
+class CommandRunner(object):
+ def __init__(self, module):
+ self.module = module
+
+ self.items = list()
+ self.conditionals = set()
+
+ self.commands = list()
+
+ self.retries = 10
+ self.interval = 1
+
+ self.match = "all"
+
+ self._default_output = module.connection.default_output
+
+ def add_command(
+ self, command, output=None, prompt=None, response=None, **kwargs
+ ):
+ if command in [str(c) for c in self.commands]:
+ raise AddCommandError(
+ "duplicated command detected", command=command
+ )
+ cmd = self.module.cli.to_command(
+ command, output=output, prompt=prompt, response=response, **kwargs
+ )
+ self.commands.append(cmd)
+
+ def get_command(self, command, output=None):
+ for cmd in self.commands:
+ if cmd.command == command:
+ return cmd.response
+ raise ValueError("command '%s' not found" % command)
+
+ def get_responses(self):
+ return [cmd.response for cmd in self.commands]
+
+ def add_conditional(self, condition):
+ try:
+ self.conditionals.add(Conditional(condition))
+ except AttributeError as exc:
+ raise AddConditionError(msg=str(exc), condition=condition)
+
+ def run(self):
+ while self.retries > 0:
+ self.module.cli.add_commands(self.commands)
+ responses = self.module.cli.run_commands()
+
+ for item in list(self.conditionals):
+ if item(responses):
+ if self.match == "any":
+ return item
+ self.conditionals.remove(item)
+
+ if not self.conditionals:
+ break
+
+ time.sleep(self.interval)
+ self.retries -= 1
+ else:
+ failed_conditions = [item.raw for item in self.conditionals]
+ errmsg = (
+ "One or more conditional statements have not been satisfied"
+ )
+ raise FailedConditionsError(errmsg, failed_conditions)
+
+
+class Conditional(object):
+ """Used in command modules to evaluate waitfor conditions
+ """
+
+ OPERATORS = {
+ "eq": ["eq", "=="],
+ "neq": ["neq", "ne", "!="],
+ "gt": ["gt", ">"],
+ "ge": ["ge", ">="],
+ "lt": ["lt", "<"],
+ "le": ["le", "<="],
+ "contains": ["contains"],
+ "matches": ["matches"],
+ }
+
+ def __init__(self, conditional, encoding=None):
+ self.raw = conditional
+ self.negate = False
+ try:
+ components = shlex.split(conditional)
+ key, val = components[0], components[-1]
+ op_components = components[1:-1]
+ if "not" in op_components:
+ self.negate = True
+ op_components.pop(op_components.index("not"))
+ op = op_components[0]
+
+ except ValueError:
+ raise ValueError("failed to parse conditional")
+
+ self.key = key
+ self.func = self._func(op)
+ self.value = self._cast_value(val)
+
+ def __call__(self, data):
+ value = self.get_value(dict(result=data))
+ if not self.negate:
+ return self.func(value)
+ else:
+ return not self.func(value)
+
+ def _cast_value(self, value):
+ if value in BOOLEANS_TRUE:
+ return True
+ elif value in BOOLEANS_FALSE:
+ return False
+ elif re.match(r"^\d+\.d+$", value):
+ return float(value)
+ elif re.match(r"^\d+$", value):
+ return int(value)
+ else:
+ return text_type(value)
+
+ def _func(self, oper):
+ for func, operators in self.OPERATORS.items():
+ if oper in operators:
+ return getattr(self, func)
+ raise AttributeError("unknown operator: %s" % oper)
+
+ def get_value(self, result):
+ try:
+ return self.get_json(result)
+ except (IndexError, TypeError, AttributeError):
+ msg = "unable to apply conditional to result"
+ raise FailedConditionalError(msg, self.raw)
+
+ def get_json(self, result):
+ string = re.sub(r"\[[\'|\"]", ".", self.key)
+ string = re.sub(r"[\'|\"]\]", ".", string)
+ parts = re.split(r"\.(?=[^\]]*(?:\[|$))", string)
+ for part in parts:
+ match = re.findall(r"\[(\S+?)\]", part)
+ if match:
+ key = part[: part.find("[")]
+ result = result[key]
+ for m in match:
+ try:
+ m = int(m)
+ except ValueError:
+ m = str(m)
+ result = result[m]
+ else:
+ result = result.get(part)
+ return result
+
+ def number(self, value):
+ if "." in str(value):
+ return float(value)
+ else:
+ return int(value)
+
+ def eq(self, value):
+ return value == self.value
+
+ def neq(self, value):
+ return value != self.value
+
+ def gt(self, value):
+ return self.number(value) > self.value
+
+ def ge(self, value):
+ return self.number(value) >= self.value
+
+ def lt(self, value):
+ return self.number(value) < self.value
+
+ def le(self, value):
+ return self.number(value) <= self.value
+
+ def contains(self, value):
+ return str(self.value) in value
+
+ def matches(self, value):
+ match = re.search(self.value, value, re.M)
+ return match is not None
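+
+    # A minimal usage sketch (the sample command output is illustrative):
+    #
+    #   cond = Conditional("result[0] contains Ethernet")
+    #   cond(["GigabitEthernet0/0 is up"])  # True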
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py
new file mode 100644
index 00000000..64eca157
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py
@@ -0,0 +1,686 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# (c) 2016 Red Hat Inc.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+# Networking tools for network modules only
+
+import re
+import ast
+import operator
+import socket
+import json
+
+from itertools import chain
+
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.module_utils.common._collections_compat import Mapping
+from ansible.module_utils.six import iteritems, string_types
+from ansible.module_utils import basic
+from ansible.module_utils.parsing.convert_bool import boolean
+
+# Backwards compatibility for 3rd party modules
+# TODO(pabelanger): With move to ansible.netcommon, we should clean this code
+# up and have modules import directly themselves.
+from ansible.module_utils.common.network import ( # noqa: F401
+ to_bits,
+ is_netmask,
+ is_masklen,
+ to_netmask,
+ to_masklen,
+ to_subnet,
+ to_ipv6_network,
+ VALID_MASKS,
+)
+
+try:
+ from jinja2 import Environment, StrictUndefined
+ from jinja2.exceptions import UndefinedError
+
+ HAS_JINJA2 = True
+except ImportError:
+ HAS_JINJA2 = False
+
+
+OPERATORS = frozenset(["ge", "gt", "eq", "neq", "lt", "le"])
+ALIASES = frozenset(
+ [("min", "ge"), ("max", "le"), ("exactly", "eq"), ("neq", "ne")]
+)
+
+
+def to_list(val):
+ if isinstance(val, (list, tuple, set)):
+ return list(val)
+ elif val is not None:
+ return [val]
+ else:
+ return list()
+
+
+def to_lines(stdout):
+ for item in stdout:
+ if isinstance(item, string_types):
+ item = to_text(item).split("\n")
+ yield item
+
+
+def transform_commands(module):
+ transform = ComplexList(
+ dict(
+ command=dict(key=True),
+ output=dict(),
+ prompt=dict(type="list"),
+ answer=dict(type="list"),
+ newline=dict(type="bool", default=True),
+ sendonly=dict(type="bool", default=False),
+ check_all=dict(type="bool", default=False),
+ ),
+ module,
+ )
+
+ return transform(module.params["commands"])
+
+
+def sort_list(val):
+ if isinstance(val, list):
+ return sorted(val)
+ return val
+
+
+class Entity(object):
+    """Transforms a dict with an argument spec
+
+ This class will take a dict and apply an Ansible argument spec to the
+ values. The resulting dict will contain all of the keys in the param
+ with appropriate values set.
+
+ Example::
+
+ argument_spec = dict(
+ command=dict(key=True),
+ display=dict(default='text', choices=['text', 'json']),
+ validate=dict(type='bool')
+ )
+ transform = Entity(module, argument_spec)
+ value = dict(command='foo')
+ result = transform(value)
+ print result
+ {'command': 'foo', 'display': 'text', 'validate': None}
+
+ Supported argument spec:
+ * key - specifies how to map a single value to a dict
+ * read_from - read and apply the argument_spec from the module
+ * required - a value is required
+ * type - type of value (uses AnsibleModule type checker)
+ * fallback - implements fallback function
+ * choices - set of valid options
+ * default - default value
+ """
+
+ def __init__(
+ self, module, attrs=None, args=None, keys=None, from_argspec=False
+ ):
+ args = [] if args is None else args
+
+ self._attributes = attrs or {}
+ self._module = module
+
+ for arg in args:
+ self._attributes[arg] = dict()
+ if from_argspec:
+ self._attributes[arg]["read_from"] = arg
+ if keys and arg in keys:
+ self._attributes[arg]["key"] = True
+
+ self.attr_names = frozenset(self._attributes.keys())
+
+ _has_key = False
+
+ for name, attr in iteritems(self._attributes):
+ if attr.get("read_from"):
+ if attr["read_from"] not in self._module.argument_spec:
+ module.fail_json(
+ msg="argument %s does not exist" % attr["read_from"]
+ )
+ spec = self._module.argument_spec.get(attr["read_from"])
+ for key, value in iteritems(spec):
+ if key not in attr:
+ attr[key] = value
+
+ if attr.get("key"):
+ if _has_key:
+ module.fail_json(msg="only one key value can be specified")
+ _has_key = True
+ attr["required"] = True
+
+ def serialize(self):
+ return self._attributes
+
+ def to_dict(self, value):
+ obj = {}
+ for name, attr in iteritems(self._attributes):
+ if attr.get("key"):
+ obj[name] = value
+ else:
+ obj[name] = attr.get("default")
+ return obj
+
+ def __call__(self, value, strict=True):
+ if not isinstance(value, dict):
+ value = self.to_dict(value)
+
+ if strict:
+ unknown = set(value).difference(self.attr_names)
+ if unknown:
+ self._module.fail_json(
+ msg="invalid keys: %s" % ",".join(unknown)
+ )
+
+ for name, attr in iteritems(self._attributes):
+ if value.get(name) is None:
+ value[name] = attr.get("default")
+
+ if attr.get("fallback") and not value.get(name):
+ fallback = attr.get("fallback", (None,))
+ fallback_strategy = fallback[0]
+ fallback_args = []
+ fallback_kwargs = {}
+ if fallback_strategy is not None:
+ for item in fallback[1:]:
+ if isinstance(item, dict):
+ fallback_kwargs = item
+ else:
+ fallback_args = item
+ try:
+ value[name] = fallback_strategy(
+ *fallback_args, **fallback_kwargs
+ )
+ except basic.AnsibleFallbackNotFound:
+ continue
+
+ if attr.get("required") and value.get(name) is None:
+ self._module.fail_json(
+ msg="missing required attribute %s" % name
+ )
+
+ if "choices" in attr:
+ if value[name] not in attr["choices"]:
+ self._module.fail_json(
+ msg="%s must be one of %s, got %s"
+ % (name, ", ".join(attr["choices"]), value[name])
+ )
+
+ if value[name] is not None:
+ value_type = attr.get("type", "str")
+ type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[
+ value_type
+ ]
+ type_checker(value[name])
+ elif value.get(name):
+ value[name] = self._module.params[name]
+
+ return value
+
+
+class EntityCollection(Entity):
+    """Extends ``Entity`` to handle a list of dicts"""
+
+ def __call__(self, iterable, strict=True):
+ if iterable is None:
+ iterable = [
+ super(EntityCollection, self).__call__(
+ self._module.params, strict
+ )
+ ]
+
+ if not isinstance(iterable, (list, tuple)):
+ self._module.fail_json(msg="value must be an iterable")
+
+ return [
+ (super(EntityCollection, self).__call__(i, strict))
+ for i in iterable
+ ]
+
+
+# these two are for backwards compatibility and can be removed once all of the
+# modules that use them are updated
+class ComplexDict(Entity):
+ def __init__(self, attrs, module, *args, **kwargs):
+ super(ComplexDict, self).__init__(module, attrs, *args, **kwargs)
+
+
+class ComplexList(EntityCollection):
+ def __init__(self, attrs, module, *args, **kwargs):
+ super(ComplexList, self).__init__(module, attrs, *args, **kwargs)
+
+
+def dict_diff(base, comparable):
+ """ Generate a dict object of differences
+
+ This function will compare two dict objects and return the difference
+ between them as a dict object. For scalar values, the key will reflect
+    the updated value. If the key does not exist in `comparable`, then no
+    key will be returned. For lists, the value in comparable will wholly replace
+    the value in base for the key. For dicts, the returned value will
+    contain only the keys that are different.
+
+ :param base: dict object to base the diff on
+ :param comparable: dict object to compare against base
+
+ :returns: new dict object with differences
+ """
+ if not isinstance(base, dict):
+ raise AssertionError("`base` must be of type <dict>")
+ if not isinstance(comparable, dict):
+ if comparable is None:
+ comparable = dict()
+ else:
+ raise AssertionError("`comparable` must be of type <dict>")
+
+ updates = dict()
+
+ for key, value in iteritems(base):
+ if isinstance(value, dict):
+ item = comparable.get(key)
+ if item is not None:
+ sub_diff = dict_diff(value, comparable[key])
+ if sub_diff:
+ updates[key] = sub_diff
+ else:
+ comparable_value = comparable.get(key)
+ if comparable_value is not None:
+ if sort_list(base[key]) != sort_list(comparable_value):
+ updates[key] = comparable_value
+
+ for key in set(comparable.keys()).difference(base.keys()):
+ updates[key] = comparable.get(key)
+
+ return updates
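+
+
+# Illustrative sketch (added, not part of the original module):
+#
+#   base = {"mtu": 1500, "speed": "auto"}
+#   comparable = {"mtu": 9000, "speed": "auto"}
+#   dict_diff(base, comparable)  # -> {"mtu": 9000}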
+
+
+def dict_merge(base, other):
+ """ Return a new dict object that combines base and other
+
+ This will create a new dict object that is a combination of the key/value
+ pairs from base and other. When both keys exist, the value will be
+ selected from other. If the value is a list object, the two lists will
+ be combined and duplicate entries removed.
+
+ :param base: dict object to serve as base
+ :param other: dict object to combine with base
+
+ :returns: new combined dict object
+ """
+ if not isinstance(base, dict):
+ raise AssertionError("`base` must be of type <dict>")
+ if not isinstance(other, dict):
+ raise AssertionError("`other` must be of type <dict>")
+
+ combined = dict()
+
+ for key, value in iteritems(base):
+ if isinstance(value, dict):
+ if key in other:
+ item = other.get(key)
+ if item is not None:
+ if isinstance(other[key], Mapping):
+ combined[key] = dict_merge(value, other[key])
+ else:
+ combined[key] = other[key]
+ else:
+ combined[key] = item
+ else:
+ combined[key] = value
+ elif isinstance(value, list):
+ if key in other:
+ item = other.get(key)
+ if item is not None:
+ try:
+ combined[key] = list(set(chain(value, item)))
+ except TypeError:
+ value.extend([i for i in item if i not in value])
+ combined[key] = value
+ else:
+ combined[key] = item
+ else:
+ combined[key] = value
+ else:
+ if key in other:
+ other_value = other.get(key)
+ if other_value is not None:
+ if sort_list(base[key]) != sort_list(other_value):
+ combined[key] = other_value
+ else:
+ combined[key] = value
+ else:
+ combined[key] = other_value
+ else:
+ combined[key] = value
+
+ for key in set(other.keys()).difference(base.keys()):
+ combined[key] = other.get(key)
+
+ return combined
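+
+
+# Illustrative sketch (added, not part of the original module):
+#
+#   dict_merge({"a": 1, "b": [1, 2]}, {"a": 2, "b": [2, 3], "c": 4})
+#   # -> {"a": 2, "b": [1, 2, 3], "c": 4}  (merged list order is not guaranteed)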
+
+
+def param_list_to_dict(param_list, unique_key="name", remove_key=True):
+ """Rotates a list of dictionaries to be a dictionary of dictionaries.
+
+ :param param_list: The aforementioned list of dictionaries
+ :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value
+ behind this key will be the key each dictionary can be found at in the new root dictionary
+ :param remove_key: If True, remove unique_key from the individual dictionaries before returning.
+ """
+ param_dict = {}
+ for params in param_list:
+ params = params.copy()
+ if remove_key:
+ name = params.pop(unique_key)
+ else:
+ name = params.get(unique_key)
+ param_dict[name] = params
+
+ return param_dict
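+
+
+# Illustrative sketch (added, not part of the original module):
+#
+#   param_list_to_dict([
+#       {"name": "eth0", "mtu": 1500},
+#       {"name": "eth1", "mtu": 9000},
+#   ])
+#   # -> {"eth0": {"mtu": 1500}, "eth1": {"mtu": 9000}}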
+
+
+def conditional(expr, val, cast=None):
+ match = re.match(r"^(.+)\((.+)\)$", str(expr), re.I)
+ if match:
+ op, arg = match.groups()
+ else:
+ op = "eq"
+ if " " in str(expr):
+ raise AssertionError("invalid expression: cannot contain spaces")
+ arg = expr
+
+ if cast is None and val is not None:
+ arg = type(val)(arg)
+ elif callable(cast):
+ arg = cast(arg)
+ val = cast(val)
+
+ op = next((oper for alias, oper in ALIASES if op == alias), op)
+
+ if not hasattr(operator, op) and op not in OPERATORS:
+ raise ValueError("unknown operator: %s" % op)
+
+ func = getattr(operator, op)
+ return func(val, arg)
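+
+
+# Illustrative sketch (added, not part of the original module):
+#
+#   conditional("ge(2)", 3)      # -> True; arg "2" is cast to type(val)
+#   conditional("eq(up)", "up")  # -> True
+#   conditional("min(10)", 5)    # -> False; "min" is aliased to "ge"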
+
+
+def ternary(value, true_val, false_val):
+ """ value ? true_val : false_val """
+ if value:
+ return true_val
+ else:
+ return false_val
+
+
+def remove_default_spec(spec):
+ for item in spec:
+ if "default" in spec[item]:
+ del spec[item]["default"]
+
+
+def validate_ip_address(address):
+ try:
+ socket.inet_aton(address)
+ except socket.error:
+ return False
+ return address.count(".") == 3
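+
+
+# Note (added): socket.inet_aton() also accepts shorthand forms such as
+# "127.1", so the trailing count(".") == 3 check ensures a full dotted quad:
+#
+#   validate_ip_address("192.168.1.1")  # True
+#   validate_ip_address("192.168.1")    # False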
+
+
+def validate_ip_v6_address(address):
+ try:
+ socket.inet_pton(socket.AF_INET6, address)
+ except socket.error:
+ return False
+ return True
+
+
+def validate_prefix(prefix):
+ if prefix and not 0 <= int(prefix) <= 32:
+ return False
+ return True
+
+
+def load_provider(spec, args):
+ provider = args.get("provider") or {}
+ for key, value in iteritems(spec):
+ if key not in provider:
+ if "fallback" in value:
+ provider[key] = _fallback(value["fallback"])
+ elif "default" in value:
+ provider[key] = value["default"]
+ else:
+ provider[key] = None
+ if "authorize" in provider:
+ # Coerce authorize to provider if a string has somehow snuck in.
+ provider["authorize"] = boolean(provider["authorize"] or False)
+ args["provider"] = provider
+ return provider
+
+
+def _fallback(fallback):
+ strategy = fallback[0]
+ args = []
+ kwargs = {}
+
+ for item in fallback[1:]:
+ if isinstance(item, dict):
+ kwargs = item
+ else:
+ args = item
+ try:
+ return strategy(*args, **kwargs)
+ except basic.AnsibleFallbackNotFound:
+ pass
+
+
+def generate_dict(spec):
+ """
+ Generate dictionary which is in sync with argspec
+
+ :param spec: A dictionary that is the argspec of the module
+ :rtype: A dictionary
+ :returns: A dictionary in sync with argspec with default value
+ """
+ obj = {}
+ if not spec:
+ return obj
+
+ for key, val in iteritems(spec):
+ if "default" in val:
+ dct = {key: val["default"]}
+ elif "type" in val and val["type"] == "dict":
+ dct = {key: generate_dict(val["options"])}
+ else:
+ dct = {key: None}
+ obj.update(dct)
+ return obj
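+
+
+# Illustrative sketch (added, not part of the original module):
+#
+#   spec = dict(
+#       mtu=dict(type="int", default=1500),
+#       description=dict(type="str"),
+#   )
+#   generate_dict(spec)  # -> {"mtu": 1500, "description": None}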
+
+
+def parse_conf_arg(cfg, arg):
+ """
+ Parse config based on argument
+
+ :param cfg: A text string which is a line of configuration.
+ :param arg: A text string which is to be matched.
+ :rtype: A text string
+ :returns: A text string if match is found
+ """
+ match = re.search(r"%s (.+)(\n|$)" % arg, cfg, re.M)
+ if match:
+ result = match.group(1).strip()
+ else:
+ result = None
+ return result
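+
+
+# Illustrative sketch (added, not part of the original module):
+#
+#   cfg = "interface Ethernet0/1\n mtu 9000\n"
+#   parse_conf_arg(cfg, "mtu")    # -> "9000"
+#   parse_conf_arg(cfg, "speed")  # -> None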
+
+
+def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str="no"):
+ """
+ Parse config based on command
+
+ :param cfg: A text string which is a line of configuration.
+ :param cmd: A text string which is the command to be matched
+ :param res1: A text string to be returned if the command is present
+ :param res2: A text string to be returned if the negate command
+ is present
+ :param delete_str: A text string to identify the start of the
+ negate command
+ :rtype: A text string
+ :returns: A text string if match is found
+ """
+ match = re.search(r"\n\s+%s(\n|$)" % cmd, cfg)
+ if match:
+ return res1
+ if res2 is not None:
+ match = re.search(r"\n\s+%s %s(\n|$)" % (delete_str, cmd), cfg)
+ if match:
+ return res2
+ return None
+
+
+def get_xml_conf_arg(cfg, path, data="text"):
+ """
+ :param cfg: The top level configuration lxml Element tree object
+ :param path: The relative xpath w.r.t to top level element (cfg)
+ to be searched in the xml hierarchy
+ :param data: The type of data to be returned for the matched xml node.
+ Valid values are text, tag, attrib, with default as text.
+ :return: Returns the required type for the matched xml node or else None
+ """
+ match = cfg.xpath(path)
+ if len(match):
+ if data == "tag":
+ result = getattr(match[0], "tag")
+ elif data == "attrib":
+ result = getattr(match[0], "attrib")
+ else:
+ result = getattr(match[0], "text")
+ else:
+ result = None
+ return result
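+
+
+# Illustrative sketch (added; assumes `cfg` is an lxml element for
+# <config><interface><name>ge-0/0/0</name></interface></config>):
+#
+#   get_xml_conf_arg(cfg, "interface/name")         # -> "ge-0/0/0"
+#   get_xml_conf_arg(cfg, "interface/name", "tag")  # -> "name"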
+
+
+def remove_empties(cfg_dict):
+ """
+ Generate final config dictionary
+
+ :param cfg_dict: A dictionary parsed in the facts system
+ :rtype: A dictionary
+ :returns: A dictionary by eliminating keys that have null values
+ """
+ final_cfg = {}
+ if not cfg_dict:
+ return final_cfg
+
+ for key, val in iteritems(cfg_dict):
+ dct = None
+ if isinstance(val, dict):
+ child_val = remove_empties(val)
+ if child_val:
+ dct = {key: child_val}
+ elif (
+ isinstance(val, list)
+ and val
+ and all([isinstance(x, dict) for x in val])
+ ):
+ child_val = [remove_empties(x) for x in val]
+ if child_val:
+ dct = {key: child_val}
+ elif val not in [None, [], {}, (), ""]:
+ dct = {key: val}
+ if dct:
+ final_cfg.update(dct)
+ return final_cfg
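+
+
+# Illustrative sketch (added, not part of the original module):
+#
+#   remove_empties({"name": "eth0", "description": "", "vifs": None})
+#   # -> {"name": "eth0"}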
+
+
+def validate_config(spec, data):
+ """
+    Validate the input data against the AnsibleModule spec format
+    :param spec: Ansible argument spec
+    :param data: Data to be validated
+    :return: The validated module parameters
+ """
+ params = basic._ANSIBLE_ARGS
+ basic._ANSIBLE_ARGS = to_bytes(json.dumps({"ANSIBLE_MODULE_ARGS": data}))
+ validated_data = basic.AnsibleModule(spec).params
+ basic._ANSIBLE_ARGS = params
+ return validated_data
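+
+
+# Illustrative sketch (added): validate_config() temporarily swaps the global
+# module arguments so AnsibleModule re-runs its type checking over `data`:
+#
+#   spec = dict(mtu=dict(type="int", default=1500))
+#   validate_config(spec, {"mtu": "9000"})  # -> {"mtu": 9000}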
+
+
+def search_obj_in_list(name, lst, key="name"):
+ if not lst:
+ return None
+ else:
+ for item in lst:
+ if item.get(key) == name:
+ return item
+
+
+class Template:
+ def __init__(self):
+ if not HAS_JINJA2:
+ raise ImportError(
+ "jinja2 is required but does not appear to be installed. "
+ "It can be installed using `pip install jinja2`"
+ )
+
+ self.env = Environment(undefined=StrictUndefined)
+ self.env.filters.update({"ternary": ternary})
+
+ def __call__(self, value, variables=None, fail_on_undefined=True):
+ variables = variables or {}
+
+ if not self.contains_vars(value):
+ return value
+
+ try:
+ value = self.env.from_string(value).render(variables)
+ except UndefinedError:
+ if not fail_on_undefined:
+ return None
+ raise
+
+ if value:
+ try:
+ return ast.literal_eval(value)
+ except Exception:
+ return str(value)
+ else:
+ return None
+
+ def contains_vars(self, data):
+ if isinstance(data, string_types):
+ for marker in (
+ self.env.block_start_string,
+ self.env.variable_start_string,
+ self.env.comment_start_string,
+ ):
+ if marker in data:
+ return True
+ return False
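+
+
+# Illustrative usage (added sketch; requires jinja2):
+#
+#   t = Template()
+#   t("{{ mtu + 100 }}", {"mtu": 1400})  # -> 1500 (literal_eval of "1500")
+#   t("static text")                     # -> "static text" (no vars present)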
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py
new file mode 100644
index 00000000..1f03299b
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py
@@ -0,0 +1,147 @@
+#
+# (c) 2018 Red Hat, Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+import json
+
+from copy import deepcopy
+from contextlib import contextmanager
+
+try:
+ from lxml.etree import fromstring, tostring
+except ImportError:
+ from xml.etree.ElementTree import fromstring, tostring
+
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.module_utils.connection import Connection, ConnectionError
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import (
+ NetconfConnection,
+)
+
+
+IGNORE_XML_ATTRIBUTE = ()
+
+
+def get_connection(module):
+ if hasattr(module, "_netconf_connection"):
+ return module._netconf_connection
+
+ capabilities = get_capabilities(module)
+ network_api = capabilities.get("network_api")
+ if network_api == "netconf":
+ module._netconf_connection = NetconfConnection(module._socket_path)
+ else:
+ module.fail_json(msg="Invalid connection type %s" % network_api)
+
+ return module._netconf_connection
+
+
+def get_capabilities(module):
+ if hasattr(module, "_netconf_capabilities"):
+ return module._netconf_capabilities
+
+ capabilities = Connection(module._socket_path).get_capabilities()
+ module._netconf_capabilities = json.loads(capabilities)
+ return module._netconf_capabilities
+
+
+def lock_configuration(module, target=None):
+ conn = get_connection(module)
+ return conn.lock(target=target)
+
+
+def unlock_configuration(module, target=None):
+ conn = get_connection(module)
+ return conn.unlock(target=target)
+
+
+@contextmanager
+def locked_config(module, target=None):
+ try:
+ lock_configuration(module, target=target)
+ yield
+ finally:
+ unlock_configuration(module, target=target)
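+
+
+# Illustrative usage (added sketch, not part of the original file):
+#
+#   with locked_config(module, target="candidate"):
+#       get_config(module, source="candidate")
+#
+# The lock is released in the finally block even if the body raises.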
+
+
+def get_config(module, source, filter=None, lock=False):
+ conn = get_connection(module)
+ try:
+ locked = False
+ if lock:
+ conn.lock(target=source)
+ locked = True
+ response = conn.get_config(source=source, filter=filter)
+
+ except ConnectionError as e:
+ module.fail_json(
+ msg=to_text(e, errors="surrogate_then_replace").strip()
+ )
+
+ finally:
+ if locked:
+ conn.unlock(target=source)
+
+ return response
+
+
+def get(module, filter, lock=False):
+ conn = get_connection(module)
+ try:
+ locked = False
+ if lock:
+ conn.lock(target="running")
+ locked = True
+
+ response = conn.get(filter=filter)
+
+ except ConnectionError as e:
+ module.fail_json(
+ msg=to_text(e, errors="surrogate_then_replace").strip()
+ )
+
+ finally:
+ if locked:
+ conn.unlock(target="running")
+
+ return response
+
+
+def dispatch(module, request):
+ conn = get_connection(module)
+ try:
+ response = conn.dispatch(request)
+ except ConnectionError as e:
+ module.fail_json(
+ msg=to_text(e, errors="surrogate_then_replace").strip()
+ )
+
+ return response
+
+
+def sanitize_xml(data):
+ tree = fromstring(
+ to_bytes(deepcopy(data), errors="surrogate_then_replace")
+ )
+    for element in tree.iter():
+ # remove attributes
+ attribute = element.attrib
+ if attribute:
+ for key in list(attribute):
+ if key not in IGNORE_XML_ATTRIBUTE:
+ attribute.pop(key)
+ return to_text(tostring(tree), errors="surrogate_then_replace").strip()
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py
new file mode 100644
index 00000000..fba46be0
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py
@@ -0,0 +1,61 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# (c) 2018 Red Hat Inc.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from ansible.module_utils.connection import Connection
+
+
+def get(module, path=None, content=None, fields=None, output="json"):
+ if path is None:
+ raise ValueError("path value must be provided")
+    if content:
+        path += "?" + "content=%s" % content
+    if fields:
+        # use "&" when a query string has already been started by `content`
+        path += ("&" if "?" in path else "?") + "field=%s" % fields
+
+ accept = None
+ if output == "xml":
+ accept = "application/yang-data+xml"
+
+ connection = Connection(module._socket_path)
+ return connection.send_request(
+ None, path=path, method="GET", accept=accept
+ )
+
+
+def edit_config(module, path=None, content=None, method="GET", format="json"):
+ if path is None:
+ raise ValueError("path value must be provided")
+
+ content_type = None
+ if format == "xml":
+ content_type = "application/yang-data+xml"
+
+ connection = Connection(module._socket_path)
+ return connection.send_request(
+ content, path=path, method=method, content_type=content_type
+ )
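+
+
+# Illustrative usage (added sketch; the path below is a hypothetical
+# RESTCONF data resource, not something defined in this file):
+#
+#   get(module, path="data/ietf-interfaces:interfaces")
+#   edit_config(module, path="data/ietf-interfaces:interfaces",
+#               content=payload, method="PATCH")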
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py
new file mode 100644
index 00000000..c1384c1d
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py
@@ -0,0 +1,444 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2018, Ansible by Red Hat, inc
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: cli_config
+author: Trishna Guha (@trishnaguha)
+notes:
+- The commands will be returned only for platforms that do not support onbox diff.
+ The C(--diff) option with the playbook will return the difference in configuration
+  for devices that have support for onbox diff.
+short_description: Push text based configuration to network devices over network_cli
+description:
+- This module provides platform agnostic way of pushing text based configuration to
+ network devices over network_cli connection plugin.
+extends_documentation_fragment:
+- ansible.netcommon.network_agnostic
+options:
+ config:
+ description:
+    - The config to be pushed to the network device. This argument is mutually exclusive
+      with C(rollback), and either one of the options should be given as input. The
+      config should have indentation that the device uses.
+ type: str
+ commit:
+ description:
+ - The C(commit) argument instructs the module to push the configuration to the
+ device. This is mapped to module check mode.
+ type: bool
+ replace:
+ description:
+    - If the C(replace) argument is set to C(yes), it will replace the entire running-config
+      of the device with the C(config) argument value. For devices that support replacing
+      the running configuration from a file on the device, like NXOS/JUNOS, the C(replace)
+      argument takes the path to the file on the device that will be used for replacing
+      the entire running-config. The value of the C(config) option should be I(None) for
+      such devices. Nexus 9K devices only support replace. Use I(net_put) or I(nxos_file_copy)
+      in the case of NXOS modules to copy the flat file to the remote device and then set
+      the full path to this argument.
+ type: str
+ backup:
+ description:
+ - This argument will cause the module to create a full backup of the current running
+ config from the remote device before any changes are made. If the C(backup_options)
+ value is not given, the backup file is written to the C(backup) folder in the
+ playbook root directory or role root directory, if playbook is part of an ansible
+ role. If the directory does not exist, it is created.
+ type: bool
+ default: 'no'
+ rollback:
+ description:
+ - The C(rollback) argument instructs the module to rollback the current configuration
+ to the identifier specified in the argument. If the specified rollback identifier
+ does not exist on the remote device, the module will fail. To rollback to the
+ most recent commit, set the C(rollback) argument to 0. This option is mutually
+ exclusive with C(config).
+ commit_comment:
+ description:
+ - The C(commit_comment) argument specifies a text string to be used when committing
+ the configuration. If the C(commit) argument is set to False, this argument
+ is silently ignored. This argument is only valid for the platforms that support
+ commit operation with comment.
+ type: str
+ defaults:
+ description:
+ - The I(defaults) argument will influence how the running-config is collected
+ from the device. When the value is set to true, the command used to collect
+      the running-config is appended with the all keyword. When the value is set to
+ false, the command is issued without the all keyword.
+ default: 'no'
+ type: bool
+ multiline_delimiter:
+ description:
+ - This argument is used when pushing a multiline configuration element to the
+ device. It specifies the character to use as the delimiting character. This
+ only applies to the configuration action.
+ type: str
+ diff_replace:
+ description:
+ - Instructs the module on the way to perform the configuration on the device.
+ If the C(diff_replace) argument is set to I(line) then the modified lines are
+ pushed to the device in configuration mode. If the argument is set to I(block)
+ then the entire command block is pushed to the device in configuration mode
+ if any line is not correct. Note that this parameter will be ignored if the
+ platform has onbox diff support.
+ choices:
+ - line
+ - block
+ - config
+ diff_match:
+ description:
+ - Instructs the module on the way to perform the matching of the set of commands
+ against the current device config. If C(diff_match) is set to I(line), commands
+ are matched line by line. If C(diff_match) is set to I(strict), command lines
+ are matched with respect to position. If C(diff_match) is set to I(exact), command
+ lines must be an equal match. Finally, if C(diff_match) is set to I(none), the
+ module will not attempt to compare the source configuration with the running
+ configuration on the remote device. Note that this parameter will be ignored
+ if the platform has onbox diff support.
+ choices:
+ - line
+ - strict
+ - exact
+ - none
+ diff_ignore_lines:
+ description:
+ - Use this argument to specify one or more lines that should be ignored during
+ the diff. This is used for lines in the configuration that are automatically
+ updated by the system. This argument takes a list of regular expressions or
+ exact line matches. Note that this parameter will be ignored if the platform
+ has onbox diff support.
+ backup_options:
+ description:
+ - This is a dict object containing configurable options related to backup file
+ path. The value of this option is read only when C(backup) is set to I(yes),
+ if C(backup) is set to I(no) this option will be silently ignored.
+ suboptions:
+ filename:
+ description:
+ - The filename to be used to store the backup configuration. If the filename
+ is not given it will be generated based on the hostname, current time and
+          date, in the format <hostname>_config.<current-date>@<current-time>
+ dir_path:
+ description:
+        - This option provides the path ending with directory name in which the backup
+          configuration file will be stored. If the directory does not exist it will
+          be created first, and the filename is either the value of C(filename) or the
+          default filename as described in the C(filename) option's description. If
+          the path value is not given, a I(backup) directory will be created in the
+          current working directory and the backup configuration will be copied in
+          C(filename) within the I(backup) directory.
+ type: path
+ type: dict
+"""
+
+EXAMPLES = """
+- name: configure device with config
+ cli_config:
+ config: "{{ lookup('template', 'basic/config.j2') }}"
+
+- name: multiline config
+ cli_config:
+ config: |
+ hostname foo
+ feature nxapi
+
+- name: configure device with config with defaults enabled
+ cli_config:
+ config: "{{ lookup('template', 'basic/config.j2') }}"
+ defaults: yes
+
+- name: Use diff_match
+ cli_config:
+ config: "{{ lookup('file', 'interface_config') }}"
+ diff_match: none
+
+- name: nxos replace config
+ cli_config:
+ replace: 'bootflash:nxoscfg'
+
+- name: junos replace config
+ cli_config:
+ replace: '/var/home/ansible/junos01.cfg'
+
+- name: commit with comment
+ cli_config:
+ config: set system host-name foo
+ commit_comment: this is a test
+
+- name: configurable backup path
+ cli_config:
+ config: "{{ lookup('template', 'basic/config.j2') }}"
+ backup: yes
+ backup_options:
+ filename: backup.cfg
+ dir_path: /home/user
+"""
+
+RETURN = """
+commands:
+ description: The set of commands that will be pushed to the remote device
+ returned: always
+ type: list
+ sample: ['interface Loopback999', 'no shutdown']
+backup_path:
+ description: The full path to the backup file
+ returned: when backup is yes
+ type: str
+ sample: /playbooks/ansible/backup/hostname_config.2016-07-16@22:28:34
+"""
+
+import json
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.connection import Connection
+from ansible.module_utils._text import to_text
+
+
+def validate_args(module, device_operations):
+ """validate param if it is supported on the platform
+ """
+ feature_list = [
+ "replace",
+ "rollback",
+ "commit_comment",
+ "defaults",
+ "multiline_delimiter",
+ "diff_replace",
+ "diff_match",
+ "diff_ignore_lines",
+ ]
+
+ for feature in feature_list:
+ if module.params[feature]:
+ supports_feature = device_operations.get("supports_%s" % feature)
+ if supports_feature is None:
+ module.fail_json(
+ "This platform does not specify whether %s is supported or not. "
+ "Please report an issue against this platform's cliconf plugin."
+ % feature
+ )
+ elif not supports_feature:
+ module.fail_json(
+ msg="Option %s is not supported on this platform" % feature
+ )
+
+
+def run(
+ module, device_operations, connection, candidate, running, rollback_id
+):
+ result = {}
+ resp = {}
+ config_diff = []
+ banner_diff = {}
+
+ replace = module.params["replace"]
+ commit_comment = module.params["commit_comment"]
+ multiline_delimiter = module.params["multiline_delimiter"]
+ diff_replace = module.params["diff_replace"]
+ diff_match = module.params["diff_match"]
+ diff_ignore_lines = module.params["diff_ignore_lines"]
+
+ commit = not module.check_mode
+
+ if replace in ("yes", "true", "True"):
+ replace = True
+ elif replace in ("no", "false", "False"):
+ replace = False
+
+ if (
+ replace is not None
+ and replace not in [True, False]
+ and candidate is not None
+ ):
+ module.fail_json(
+ msg="Replace value '%s' is a configuration file path already"
+ " present on the device. Hence 'replace' and 'config' options"
+ " are mutually exclusive" % replace
+ )
+
+ if rollback_id is not None:
+ resp = connection.rollback(rollback_id, commit)
+ if "diff" in resp:
+ result["changed"] = True
+
+ elif device_operations.get("supports_onbox_diff"):
+ if diff_replace:
+ module.warn(
+ "diff_replace is ignored as the device supports onbox diff"
+ )
+ if diff_match:
+ module.warn(
+                "diff_match is ignored as the device supports onbox diff"
+ )
+ if diff_ignore_lines:
+ module.warn(
+ "diff_ignore_lines is ignored as the device supports onbox diff"
+ )
+
+ if candidate and not isinstance(candidate, list):
+ candidate = candidate.strip("\n").splitlines()
+
+ kwargs = {
+ "candidate": candidate,
+ "commit": commit,
+ "replace": replace,
+ "comment": commit_comment,
+ }
+ resp = connection.edit_config(**kwargs)
+
+ if "diff" in resp:
+ result["changed"] = True
+
+ elif device_operations.get("supports_generate_diff"):
+ kwargs = {"candidate": candidate, "running": running}
+ if diff_match:
+ kwargs.update({"diff_match": diff_match})
+ if diff_replace:
+ kwargs.update({"diff_replace": diff_replace})
+ if diff_ignore_lines:
+ kwargs.update({"diff_ignore_lines": diff_ignore_lines})
+
+ diff_response = connection.get_diff(**kwargs)
+
+ config_diff = diff_response.get("config_diff")
+ banner_diff = diff_response.get("banner_diff")
+
+ if config_diff:
+ if isinstance(config_diff, list):
+ candidate = config_diff
+ else:
+ candidate = config_diff.splitlines()
+
+ kwargs = {
+ "candidate": candidate,
+ "commit": commit,
+ "replace": replace,
+ "comment": commit_comment,
+ }
+ if commit:
+ connection.edit_config(**kwargs)
+ result["changed"] = True
+ result["commands"] = config_diff.split("\n")
+
+ if banner_diff:
+ candidate = json.dumps(banner_diff)
+
+ kwargs = {"candidate": candidate, "commit": commit}
+ if multiline_delimiter:
+ kwargs.update({"multiline_delimiter": multiline_delimiter})
+ if commit:
+ connection.edit_banner(**kwargs)
+ result["changed"] = True
+
+ if module._diff:
+ if "diff" in resp:
+ result["diff"] = {"prepared": resp["diff"]}
+ else:
+ diff = ""
+ if config_diff:
+ if isinstance(config_diff, list):
+ diff += "\n".join(config_diff)
+ else:
+ diff += config_diff
+ if banner_diff:
+ diff += json.dumps(banner_diff)
+ result["diff"] = {"prepared": diff}
+
+ return result
+
+
+def main():
+ """main entry point for execution
+ """
+ backup_spec = dict(filename=dict(), dir_path=dict(type="path"))
+ argument_spec = dict(
+ backup=dict(default=False, type="bool"),
+ backup_options=dict(type="dict", options=backup_spec),
+ config=dict(type="str"),
+ commit=dict(type="bool"),
+ replace=dict(type="str"),
+ rollback=dict(type="int"),
+ commit_comment=dict(type="str"),
+ defaults=dict(default=False, type="bool"),
+ multiline_delimiter=dict(type="str"),
+ diff_replace=dict(choices=["line", "block", "config"]),
+ diff_match=dict(choices=["line", "strict", "exact", "none"]),
+ diff_ignore_lines=dict(type="list"),
+ )
+
+ mutually_exclusive = [("config", "rollback")]
+ required_one_of = [["backup", "config", "rollback"]]
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ mutually_exclusive=mutually_exclusive,
+ required_one_of=required_one_of,
+ supports_check_mode=True,
+ )
+
+ result = {"changed": False}
+
+ connection = Connection(module._socket_path)
+ capabilities = module.from_json(connection.get_capabilities())
+
+ if capabilities:
+ device_operations = capabilities.get("device_operations", dict())
+ validate_args(module, device_operations)
+ else:
+ device_operations = dict()
+
+ if module.params["defaults"]:
+ if "get_default_flag" in capabilities.get("rpc"):
+ flags = connection.get_default_flag()
+ else:
+ flags = "all"
+ else:
+ flags = []
+
+ candidate = module.params["config"]
+ candidate = (
+ to_text(candidate, errors="surrogate_then_replace")
+ if candidate
+ else None
+ )
+ running = connection.get_config(flags=flags)
+ rollback_id = module.params["rollback"]
+
+ if module.params["backup"]:
+ result["__backup__"] = running
+
+ if candidate or rollback_id or module.params["replace"]:
+ try:
+ result.update(
+ run(
+ module,
+ device_operations,
+ connection,
+ candidate,
+ running,
+ rollback_id,
+ )
+ )
+ except Exception as exc:
+ module.fail_json(msg=to_text(exc))
+
+ module.exit_json(**result)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py
new file mode 100644
index 00000000..f0910f52
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py
@@ -0,0 +1,71 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2018, Ansible by Red Hat, inc
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: net_get
+author: Deepak Agrawal (@dagrawal)
+short_description: Copy a file from a network device to Ansible Controller
+description:
+- This module provides functionality to copy a file from a network device to the Ansible controller.
+extends_documentation_fragment:
+- ansible.netcommon.network_agnostic
+options:
+ src:
+ description:
+ - Specifies the source file. The path to the source file can either be the full
+      path on the network device or a relative path, as supported by the destination
+      network device.
+ required: true
+ protocol:
+ description:
+ - Protocol used to transfer file.
+ default: scp
+ choices:
+ - scp
+ - sftp
+ dest:
+ description:
+ - Specifies the destination file. The path to the destination file can either
+ be the full path on the Ansible control host or a relative path from the playbook
+ or role root directory.
+ default:
+ - Same filename as specified in I(src). The path will be playbook root or role
+ root directory if playbook is part of a role.
+requirements:
+- scp
+notes:
+- Some devices need specific configurations to be enabled before scp can work. These
+  configurations should be pre-configured before using this module, e.g. ios - C(ip scp
+  server enable).
+- User privilege to do scp on the network device should be pre-configured, e.g. ios -
+  user privilege 15 is needed by default to allow scp.
+- If I(dest) is not given, the default destination is derived from the I(src) filename.
+"""
+
+EXAMPLES = """
+- name: copy file from the network device to Ansible controller
+ net_get:
+ src: running_cfg_ios1.txt
+
+- name: copy file from ios to common location at /tmp
+ net_get:
+ src: running_cfg_sw1.txt
+ dest : /tmp/ios1.txt
+"""
+
+RETURN = """
+"""
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py
new file mode 100644
index 00000000..2fc4a98c
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py
@@ -0,0 +1,82 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# (c) 2018, Ansible by Red Hat, inc
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: net_put
+author: Deepak Agrawal (@dagrawal)
+short_description: Copy a file from Ansible Controller to a network device
+description:
+- This module provides functionality to copy a file from the Ansible controller to network
+ devices.
+extends_documentation_fragment:
+- ansible.netcommon.network_agnostic
+options:
+ src:
+ description:
+ - Specifies the source file. The path to the source file can either be the full
+ path on the Ansible control host or a relative path from the playbook or role
+ root directory.
+ required: true
+ protocol:
+ description:
+ - Protocol used to transfer file.
+ default: scp
+ choices:
+ - scp
+ - sftp
+ dest:
+ description:
+    - Specifies the destination file. The path to the destination file can either be
+      the full path or a relative path, as supported by network_os.
+ default:
+ - Filename from src and at default directory of user shell on network_os.
+ required: false
+ mode:
+ description:
+    - Set the file transfer mode. If mode is set to I(text) then the I(src) file will
+      go through the Jinja2 template engine to replace any vars present in the src
+      file. If mode is set to I(binary) then the file will be copied as-is to the
+      destination device.
+ default: binary
+ choices:
+ - binary
+ - text
+requirements:
+- scp
+notes:
+- Some devices need specific configurations to be enabled before scp can work. These
+  configurations should be pre-configured before using this module, e.g. ios - C(ip scp
+  server enable).
+- User privilege to do scp on the network device should be pre-configured, e.g. ios -
+  user privilege 15 is needed by default to allow scp.
+- If I(dest) is not given, the default destination is derived from the I(src) filename.
+"""
+
+EXAMPLES = """
+- name: copy file from ansible controller to a network device
+ net_put:
+ src: running_cfg_ios1.txt
+
+- name: copy file at root dir of flash in slot 3 of sw1(ios)
+ net_put:
+ src: running_cfg_sw1.txt
+ protocol: sftp
+ dest : flash3:/running_cfg_sw1.txt
+"""
+
+RETURN = """
+"""
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py
new file mode 100644
index 00000000..e9332f26
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py
@@ -0,0 +1,70 @@
+#
+# (c) 2017 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """author: Ansible Networking Team
+netconf: default
+short_description: Use default netconf plugin to run standard netconf commands as
+ per RFC
+description:
+- This default plugin provides low level abstraction apis for sending and receiving
+ netconf commands as per Netconf RFC specification.
+options:
+ ncclient_device_handler:
+ type: str
+ default: default
+ description:
+    - Specifies the ncclient device handler name for network os that support the default
+      netconf implementation as per the Netconf RFC specification. To identify the
+      ncclient device handler name, refer to the ncclient library documentation.
+"""
+import json
+
+from ansible.module_utils._text import to_text
+from ansible.plugins.netconf import NetconfBase
+
+
+class Netconf(NetconfBase):
+ def get_text(self, ele, tag):
+ try:
+ return to_text(
+ ele.find(tag).text, errors="surrogate_then_replace"
+ ).strip()
+ except AttributeError:
+ pass
+
+ def get_device_info(self):
+ device_info = dict()
+ device_info["network_os"] = "default"
+ return device_info
+
+ def get_capabilities(self):
+ result = dict()
+ result["rpc"] = self.get_base_rpc()
+ result["network_api"] = "netconf"
+ result["device_info"] = self.get_device_info()
+ result["server_capabilities"] = [c for c in self.m.server_capabilities]
+ result["client_capabilities"] = [c for c in self.m.client_capabilities]
+ result["session_id"] = self.m.session_id
+ result["device_operations"] = self.get_device_operations(
+ result["server_capabilities"]
+ )
+ return json.dumps(result)
diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py
new file mode 100644
index 00000000..e5ac2cd1
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py
@@ -0,0 +1,133 @@
+#
+# (c) 2016 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import sys
+import copy
+
+from ansible_collections.ansible.netcommon.plugins.action.network import (
+ ActionModule as ActionNetworkModule,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ load_provider,
+)
+from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import (
+ ios_provider_spec,
+)
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class ActionModule(ActionNetworkModule):
+ def run(self, tmp=None, task_vars=None):
+ del tmp # tmp no longer has any effect
+
+ module_name = self._task.action.split(".")[-1]
+ self._config_module = True if module_name == "ios_config" else False
+ persistent_connection = self._play_context.connection.split(".")[-1]
+ warnings = []
+
+ if persistent_connection == "network_cli":
+ provider = self._task.args.get("provider", {})
+ if any(provider.values()):
+ display.warning(
+ "provider is unnecessary when using network_cli and will be ignored"
+ )
+ del self._task.args["provider"]
+ elif self._play_context.connection == "local":
+ provider = load_provider(ios_provider_spec, self._task.args)
+ pc = copy.deepcopy(self._play_context)
+ pc.connection = "ansible.netcommon.network_cli"
+ pc.network_os = "cisco.ios.ios"
+ pc.remote_addr = provider["host"] or self._play_context.remote_addr
+ pc.port = int(provider["port"] or self._play_context.port or 22)
+ pc.remote_user = (
+ provider["username"] or self._play_context.connection_user
+ )
+ pc.password = provider["password"] or self._play_context.password
+ pc.private_key_file = (
+ provider["ssh_keyfile"] or self._play_context.private_key_file
+ )
+ pc.become = provider["authorize"] or False
+ if pc.become:
+ pc.become_method = "enable"
+ pc.become_pass = provider["auth_pass"]
+
+ connection = self._shared_loader_obj.connection_loader.get(
+ "ansible.netcommon.persistent",
+ pc,
+ sys.stdin,
+ task_uuid=self._task._uuid,
+ )
+
+ # TODO: Remove below code after ansible minimal is cut out
+ if connection is None:
+ pc.connection = "network_cli"
+ pc.network_os = "ios"
+ connection = self._shared_loader_obj.connection_loader.get(
+ "persistent", pc, sys.stdin, task_uuid=self._task._uuid
+ )
+
+ display.vvv(
+ "using connection plugin %s (was local)" % pc.connection,
+ pc.remote_addr,
+ )
+
+ command_timeout = (
+ int(provider["timeout"])
+ if provider["timeout"]
+ else connection.get_option("persistent_command_timeout")
+ )
+ connection.set_options(
+ direct={"persistent_command_timeout": command_timeout}
+ )
+
+ socket_path = connection.run()
+ display.vvvv("socket_path: %s" % socket_path, pc.remote_addr)
+ if not socket_path:
+ return {
+ "failed": True,
+ "msg": "unable to open shell. Please see: "
+ + "https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell",
+ }
+
+ task_vars["ansible_socket"] = socket_path
+ warnings.append(
+ [
+ "connection local support for this module is deprecated and will be removed in version 2.14, use connection %s"
+ % pc.connection
+ ]
+ )
+ else:
+ return {
+ "failed": True,
+ "msg": "Connection type %s is not valid for this module"
+ % self._play_context.connection,
+ }
+
+ result = super(ActionModule, self).run(task_vars=task_vars)
+ if warnings:
+ if "warnings" in result:
+ result["warnings"].extend(warnings)
+ else:
+ result["warnings"] = warnings
+ return result
diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py
new file mode 100644
index 00000000..8a390034
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py
@@ -0,0 +1,465 @@
+#
+# (c) 2017 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """
+---
+author: Ansible Networking Team
+cliconf: ios
+short_description: Use ios cliconf to run command on Cisco IOS platform
+description:
+ - This ios plugin provides low level abstraction apis for
+ sending and receiving CLI commands from Cisco IOS network devices.
+version_added: "2.4"
+"""
+
+import re
+import time
+import json
+
+from ansible.errors import AnsibleConnectionFailure
+from ansible.module_utils._text import to_text
+from ansible.module_utils.common._collections_compat import Mapping
+from ansible.module_utils.six import iteritems
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import (
+ NetworkConfig,
+ dumps,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ to_list,
+)
+from ansible.plugins.cliconf import CliconfBase, enable_mode
+
+
+class Cliconf(CliconfBase):
+ @enable_mode
+ def get_config(self, source="running", flags=None, format=None):
+ if source not in ("running", "startup"):
+ raise ValueError(
+ "fetching configuration from %s is not supported" % source
+ )
+
+ if format:
+ raise ValueError(
+ "'format' value %s is not supported for get_config" % format
+ )
+
+ if not flags:
+ flags = []
+ if source == "running":
+ cmd = "show running-config "
+ else:
+ cmd = "show startup-config "
+
+ cmd += " ".join(to_list(flags))
+ cmd = cmd.strip()
+
+ return self.send_command(cmd)
+
+ def get_diff(
+ self,
+ candidate=None,
+ running=None,
+ diff_match="line",
+ diff_ignore_lines=None,
+ path=None,
+ diff_replace="line",
+ ):
+ """
+        Generate diff between candidate and running configuration. If the
+        remote host supports onbox diff capabilities (i.e. supports_onbox_diff),
+        the candidate and running configurations are not required to be passed
+        as arguments. If onbox diff capability is not supported, the candidate
+        argument is mandatory and the running argument is optional.
+ :param candidate: The configuration which is expected to be present on remote host.
+ :param running: The base configuration which is used to generate diff.
+ :param diff_match: Instructs how to match the candidate configuration with current device configuration
+ Valid values are 'line', 'strict', 'exact', 'none'.
+ 'line' - commands are matched line by line
+ 'strict' - command lines are matched with respect to position
+ 'exact' - command lines must be an equal match
+ 'none' - will not compare the candidate configuration with the running configuration
+ :param diff_ignore_lines: Use this argument to specify one or more lines that should be
+ ignored during the diff. This is used for lines in the configuration
+ that are automatically updated by the system. This argument takes
+ a list of regular expressions or exact line matches.
+ :param path: The ordered set of parents that uniquely identify the section or hierarchy
+ the commands should be checked against. If the parents argument
+ is omitted, the commands are checked against the set of top
+ level or global commands.
+ :param diff_replace: Instructs on the way to perform the configuration on the device.
+ If the replace argument is set to I(line) then the modified lines are
+ pushed to the device in configuration mode. If the replace argument is
+ set to I(block) then the entire command block is pushed to the device in
+ configuration mode if any line is not correct.
+ :return: Configuration diff in json format.
+ {
+ 'config_diff': '',
+ 'banner_diff': {}
+ }
+
+ """
+ diff = {}
+ device_operations = self.get_device_operations()
+ option_values = self.get_option_values()
+
+ if candidate is None and device_operations["supports_generate_diff"]:
+ raise ValueError(
+ "candidate configuration is required to generate diff"
+ )
+
+ if diff_match not in option_values["diff_match"]:
+ raise ValueError(
+ "'match' value %s in invalid, valid values are %s"
+ % (diff_match, ", ".join(option_values["diff_match"]))
+ )
+
+ if diff_replace not in option_values["diff_replace"]:
+ raise ValueError(
+ "'replace' value %s in invalid, valid values are %s"
+ % (diff_replace, ", ".join(option_values["diff_replace"]))
+ )
+
+ # prepare candidate configuration
+ candidate_obj = NetworkConfig(indent=1)
+ want_src, want_banners = self._extract_banners(candidate)
+ candidate_obj.load(want_src)
+
+ if running and diff_match != "none":
+ # running configuration
+ have_src, have_banners = self._extract_banners(running)
+ running_obj = NetworkConfig(
+ indent=1, contents=have_src, ignore_lines=diff_ignore_lines
+ )
+ configdiffobjs = candidate_obj.difference(
+ running_obj, path=path, match=diff_match, replace=diff_replace
+ )
+
+ else:
+ configdiffobjs = candidate_obj.items
+ have_banners = {}
+
+ diff["config_diff"] = (
+ dumps(configdiffobjs, "commands") if configdiffobjs else ""
+ )
+ banners = self._diff_banners(want_banners, have_banners)
+ diff["banner_diff"] = banners if banners else {}
+ return diff
+
+ @enable_mode
+ def edit_config(
+ self, candidate=None, commit=True, replace=None, comment=None
+ ):
+ resp = {}
+ operations = self.get_device_operations()
+ self.check_edit_config_capability(
+ operations, candidate, commit, replace, comment
+ )
+
+ results = []
+ requests = []
+ if commit:
+ self.send_command("configure terminal")
+ for line in to_list(candidate):
+ if not isinstance(line, Mapping):
+ line = {"command": line}
+
+ cmd = line["command"]
+ if cmd != "end" and cmd[0] != "!":
+ results.append(self.send_command(**line))
+ requests.append(cmd)
+
+ self.send_command("end")
+ else:
+ raise ValueError("check mode is not supported")
+
+ resp["request"] = requests
+ resp["response"] = results
+ return resp
+
+ def edit_macro(
+ self, candidate=None, commit=True, replace=None, comment=None
+ ):
+ """
+ ios_config:
+ lines: "{{ macro_lines }}"
+ parents: "macro name {{ macro_name }}"
+ after: '@'
+ match: line
+ replace: block
+ """
+ resp = {}
+ operations = self.get_device_operations()
+ self.check_edit_config_capability(
+ operations, candidate, commit, replace, comment
+ )
+
+ results = []
+ requests = []
+ if commit:
+ commands = ""
+ self.send_command("config terminal")
+ time.sleep(0.1)
+ # first item: macro command
+ commands += candidate.pop(0) + "\n"
+ multiline_delimiter = candidate.pop(-1)
+ for line in candidate:
+ commands += " " + line + "\n"
+ commands += multiline_delimiter + "\n"
+ obj = {"command": commands, "sendonly": True}
+ results.append(self.send_command(**obj))
+ requests.append(commands)
+
+ time.sleep(0.1)
+ self.send_command("end", sendonly=True)
+ time.sleep(0.1)
+ results.append(self.send_command("\n"))
+ requests.append("\n")
+
+ resp["request"] = requests
+ resp["response"] = results
+ return resp
+
+ def get(
+ self,
+ command=None,
+ prompt=None,
+ answer=None,
+ sendonly=False,
+ output=None,
+ newline=True,
+ check_all=False,
+ ):
+ if not command:
+ raise ValueError("must provide value of command to execute")
+ if output:
+ raise ValueError(
+ "'output' value %s is not supported for get" % output
+ )
+
+ return self.send_command(
+ command=command,
+ prompt=prompt,
+ answer=answer,
+ sendonly=sendonly,
+ newline=newline,
+ check_all=check_all,
+ )
+
+ def get_device_info(self):
+ device_info = {}
+
+ device_info["network_os"] = "ios"
+ reply = self.get(command="show version")
+ data = to_text(reply, errors="surrogate_or_strict").strip()
+
+ match = re.search(r"Version (\S+)", data)
+ if match:
+ device_info["network_os_version"] = match.group(1).strip(",")
+
+ model_search_strs = [
+ r"^[Cc]isco (.+) \(revision",
+ r"^[Cc]isco (\S+).+bytes of .*memory",
+ ]
+ for item in model_search_strs:
+ match = re.search(item, data, re.M)
+ if match:
+ version = match.group(1).split(" ")
+ device_info["network_os_model"] = version[0]
+ break
+
+ match = re.search(r"^(.+) uptime", data, re.M)
+ if match:
+ device_info["network_os_hostname"] = match.group(1)
+
+ match = re.search(r'image file is "(.+)"', data)
+ if match:
+ device_info["network_os_image"] = match.group(1)
+
+ return device_info
+
+ def get_device_operations(self):
+ return {
+ "supports_diff_replace": True,
+ "supports_commit": False,
+ "supports_rollback": False,
+ "supports_defaults": True,
+ "supports_onbox_diff": False,
+ "supports_commit_comment": False,
+ "supports_multiline_delimiter": True,
+ "supports_diff_match": True,
+ "supports_diff_ignore_lines": True,
+ "supports_generate_diff": True,
+ "supports_replace": False,
+ }
+
+ def get_option_values(self):
+ return {
+ "format": ["text"],
+ "diff_match": ["line", "strict", "exact", "none"],
+ "diff_replace": ["line", "block"],
+ "output": [],
+ }
+
+ def get_capabilities(self):
+ result = super(Cliconf, self).get_capabilities()
+ result["rpc"] += [
+ "edit_banner",
+ "get_diff",
+ "run_commands",
+ "get_defaults_flag",
+ ]
+ result["device_operations"] = self.get_device_operations()
+ result.update(self.get_option_values())
+ return json.dumps(result)
+
+ def edit_banner(
+ self, candidate=None, multiline_delimiter="@", commit=True
+ ):
+ """
+        Edit banner on remote device
+        :param candidate: Banners to be loaded, in JSON format
+        :param multiline_delimiter: Line delimiter for banner
+        :param commit: Boolean value that indicates if the device candidate
+        configuration should be pushed into the running configuration or discarded.
+        :return: Returns the response of executing the configuration commands received
+        from the remote host
+ """
+ resp = {}
+ banners_obj = json.loads(candidate)
+ results = []
+ requests = []
+ if commit:
+ for key, value in iteritems(banners_obj):
+ key += " %s" % multiline_delimiter
+ self.send_command("config terminal", sendonly=True)
+ for cmd in [key, value, multiline_delimiter]:
+ obj = {"command": cmd, "sendonly": True}
+ results.append(self.send_command(**obj))
+ requests.append(cmd)
+
+ self.send_command("end", sendonly=True)
+ time.sleep(0.1)
+ results.append(self.send_command("\n"))
+ requests.append("\n")
+
+ resp["request"] = requests
+ resp["response"] = results
+
+ return resp
+
+ def run_commands(self, commands=None, check_rc=True):
+ if commands is None:
+ raise ValueError("'commands' value is required")
+
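+        # Each command may be a plain string or a dict in send_command form,
+        # e.g. (hypothetical prompt handling):
+        #
+        #   {"command": "clear counters", "prompt": r"\[confirm\]", "answer": "y"}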
+ responses = list()
+ for cmd in to_list(commands):
+ if not isinstance(cmd, Mapping):
+ cmd = {"command": cmd}
+
+ output = cmd.pop("output", None)
+ if output:
+ raise ValueError(
+ "'output' value %s is not supported for run_commands"
+ % output
+ )
+
+ try:
+ out = self.send_command(**cmd)
+ except AnsibleConnectionFailure as e:
+ if check_rc:
+ raise
+ out = getattr(e, "err", to_text(e))
+
+ responses.append(out)
+
+ return responses
+
+ def get_defaults_flag(self):
+ """
+ The method identifies the filter that should be used to fetch running-configuration
+ with defaults.
+ :return: valid default filter
+ """
+ out = self.get("show running-config ?")
+ out = to_text(out, errors="surrogate_then_replace")
+
+ commands = set()
+ for line in out.splitlines():
+ if line.strip():
+ commands.add(line.strip().split()[0])
+
+ if "all" in commands:
+ return "all"
+ else:
+ return "full"
+
+ def set_cli_prompt_context(self):
+ """
+ Make sure we are in the operational cli mode
+ :return: None
+ """
+ if self._connection.connected:
+ out = self._connection.get_prompt()
+
+ if out is None:
+ raise AnsibleConnectionFailure(
+ message=u"cli prompt is not identified from the last received"
+ u" response window: %s"
+ % self._connection._last_recv_window
+ )
+
+ if re.search(
+ r"config.*\)#",
+ to_text(out, errors="surrogate_then_replace").strip(),
+ ):
+ self._connection.queue_message(
+ "vvvv", "wrong context, sending end to device"
+ )
+ self._connection.send_command("end")
+
+ def _extract_banners(self, config):
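+        # Pull ^C-delimited banner blocks out of the config text, e.g.
+        # `banner motd ^CWelcome^C` yields {"banner motd": "Welcome"}, and
+        # replace each block with a `!! banner removed` placeholder.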
+ banners = {}
+ banner_cmds = re.findall(r"^banner (\w+)", config, re.M)
+ for cmd in banner_cmds:
+ regex = r"banner %s \^C(.+?)(?=\^C)" % cmd
+ match = re.search(regex, config, re.S)
+ if match:
+ key = "banner %s" % cmd
+ banners[key] = match.group(1).strip()
+
+ for cmd in banner_cmds:
+ regex = r"banner %s \^C(.+?)(?=\^C)" % cmd
+ match = re.search(regex, config, re.S)
+ if match:
+ config = config.replace(str(match.group(1)), "")
+
+ config = re.sub(r"banner \w+ \^C\^C", "!! banner removed", config)
+ return config, banners
+
+ def _diff_banners(self, want, have):
+ candidate = {}
+ for key, value in iteritems(want):
+ if value != have.get(key):
+ candidate[key] = value
+ return candidate
diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py
new file mode 100644
index 00000000..ff22d27c
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Peter Sprygada <psprygada@ansible.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+
+class ModuleDocFragment(object):
+
+ # Standard files documentation fragment
+ DOCUMENTATION = r"""options:
+ provider:
+ description:
+ - B(Deprecated)
+ - 'Starting with Ansible 2.5 we recommend using C(connection: network_cli).'
+ - For more information please see the L(IOS Platform Options guide, ../network/user_guide/platform_ios.html).
+ - HORIZONTALLINE
+ - A dict object containing connection details.
+ type: dict
+ suboptions:
+ host:
+ description:
+ - Specifies the DNS host name or address for connecting to the remote device
+ over the specified transport. The value of host is used as the destination
+ address for the transport.
+ type: str
+ required: true
+ port:
+ description:
+ - Specifies the port to use when building the connection to the remote device.
+ type: int
+ default: 22
+ username:
+ description:
+ - Configures the username to use to authenticate the connection to the remote
+ device. This value is used to authenticate the SSH session. If the value
+ is not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME)
+ will be used instead.
+ type: str
+ password:
+ description:
+ - Specifies the password to use to authenticate the connection to the remote
+ device. This value is used to authenticate the SSH session. If the value
+ is not specified in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD)
+ will be used instead.
+ type: str
+ timeout:
+ description:
+ - Specifies the timeout in seconds for communicating with the network device
+ for either connecting or sending commands. If the timeout is exceeded before
+ the operation is completed, the module will error.
+ type: int
+ default: 10
+ ssh_keyfile:
+ description:
+ - Specifies the SSH key to use to authenticate the connection to the remote
+ device. This value is the path to the key used to authenticate the SSH
+ session. If the value is not specified in the task, the value of environment
+ variable C(ANSIBLE_NET_SSH_KEYFILE) will be used instead.
+ type: path
+ authorize:
+ description:
+ - Instructs the module to enter privileged mode on the remote device before
+ sending any commands. If not specified, the device will attempt to execute
+ all commands in non-privileged mode. If the value is not specified in the
+ task, the value of environment variable C(ANSIBLE_NET_AUTHORIZE) will be
+ used instead.
+ type: bool
+ default: false
+ auth_pass:
+ description:
+ - Specifies the password to use if required to enter privileged mode on the
+ remote device. If I(authorize) is false, then this argument does nothing.
+ If the value is not specified in the task, the value of environment variable
+ C(ANSIBLE_NET_AUTH_PASS) will be used instead.
+ type: str
+notes:
+- For more information on using Ansible to manage network devices see the :ref:`Ansible
+ Network Guide <network_guide>`
+- For more information on using Ansible to manage Cisco devices see the `Cisco integration
+ page <https://www.ansible.com/integrations/networks/cisco>`_.
+"""
diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py
new file mode 100644
index 00000000..6818a0ce
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py
@@ -0,0 +1,197 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# (c) 2016 Red Hat Inc.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+import json
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.basic import env_fallback
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ to_list,
+)
+from ansible.module_utils.connection import Connection, ConnectionError
+
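+# Module-level cache of fetched configurations, keyed by the flag string, so
+# repeated get_config() calls within a single task avoid extra round trips.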
+_DEVICE_CONFIGS = {}
+
+ios_provider_spec = {
+ "host": dict(),
+ "port": dict(type="int"),
+ "username": dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])),
+ "password": dict(
+ fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]), no_log=True
+ ),
+ "ssh_keyfile": dict(
+ fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path"
+ ),
+ "authorize": dict(
+ fallback=(env_fallback, ["ANSIBLE_NET_AUTHORIZE"]), type="bool"
+ ),
+ "auth_pass": dict(
+ fallback=(env_fallback, ["ANSIBLE_NET_AUTH_PASS"]), no_log=True
+ ),
+ "timeout": dict(type="int"),
+}
+ios_argument_spec = {
+ "provider": dict(
+ type="dict", options=ios_provider_spec, removed_in_version=2.14
+ )
+}
+
+
+def get_provider_argspec():
+ return ios_provider_spec
+
+
+def get_connection(module):
+ if hasattr(module, "_ios_connection"):
+ return module._ios_connection
+
+ capabilities = get_capabilities(module)
+ network_api = capabilities.get("network_api")
+ if network_api == "cliconf":
+ module._ios_connection = Connection(module._socket_path)
+ else:
+ module.fail_json(msg="Invalid connection type %s" % network_api)
+
+ return module._ios_connection
+
+
+def get_capabilities(module):
+ if hasattr(module, "_ios_capabilities"):
+ return module._ios_capabilities
+ try:
+ capabilities = Connection(module._socket_path).get_capabilities()
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+ module._ios_capabilities = json.loads(capabilities)
+ return module._ios_capabilities
+
+
+def get_defaults_flag(module):
+ connection = get_connection(module)
+ try:
+ out = connection.get_defaults_flag()
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+ return to_text(out, errors="surrogate_then_replace").strip()
+
+
+def get_config(module, flags=None):
+ flags = to_list(flags)
+
+ section_filter = False
+ if flags and "section" in flags[-1]:
+ section_filter = True
+
+ flag_str = " ".join(flags)
+
+ try:
+ return _DEVICE_CONFIGS[flag_str]
+ except KeyError:
+ connection = get_connection(module)
+ try:
+ out = connection.get_config(flags=flags)
+ except ConnectionError as exc:
+ if section_filter:
+                # Some IOS devices don't understand `| section foo`
+ out = get_config(module, flags=flags[:-1])
+ else:
+ module.fail_json(
+ msg=to_text(exc, errors="surrogate_then_replace")
+ )
+ cfg = to_text(out, errors="surrogate_then_replace").strip()
+ _DEVICE_CONFIGS[flag_str] = cfg
+ return cfg
+
+
+def run_commands(module, commands, check_rc=True):
+ connection = get_connection(module)
+ try:
+ return connection.run_commands(commands=commands, check_rc=check_rc)
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc))
+
+
+def load_config(module, commands):
+ connection = get_connection(module)
+
+ try:
+ resp = connection.edit_config(commands)
+ return resp.get("response")
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc))
+
+
+def normalize_interface(name):
+ """Return the normalized interface name
+ """
+ if not name:
+ return
+
+ def _get_number(name):
+ digits = ""
+ for char in name:
+ if char.isdigit() or char in "/.":
+ digits += char
+ return digits
+
+ if name.lower().startswith("gi"):
+ if_type = "GigabitEthernet"
+ elif name.lower().startswith("te"):
+ if_type = "TenGigabitEthernet"
+ elif name.lower().startswith("fa"):
+ if_type = "FastEthernet"
+ elif name.lower().startswith("fo"):
+ if_type = "FortyGigabitEthernet"
+ elif name.lower().startswith("et"):
+ if_type = "Ethernet"
+ elif name.lower().startswith("vl"):
+ if_type = "Vlan"
+ elif name.lower().startswith("lo"):
+ if_type = "loopback"
+ elif name.lower().startswith("po"):
+ if_type = "port-channel"
+ elif name.lower().startswith("nv"):
+ if_type = "nve"
+ elif name.lower().startswith("twe"):
+ if_type = "TwentyFiveGigE"
+ elif name.lower().startswith("hu"):
+ if_type = "HundredGigE"
+ else:
+ if_type = None
+
+ number_list = name.split(" ")
+ if len(number_list) == 2:
+ if_number = number_list[-1].strip()
+ else:
+ if_number = _get_number(name)
+
+ if if_type:
+ proper_interface = if_type + if_number
+ else:
+ proper_interface = name
+
+ return proper_interface
diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py
new file mode 100644
index 00000000..ef383fcc
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py
@@ -0,0 +1,229 @@
+#!/usr/bin/python
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: ios_command
+author: Peter Sprygada (@privateip)
+short_description: Run commands on remote devices running Cisco IOS
+description:
+- Sends arbitrary commands to an ios node and returns the results read from the device.
+ This module includes an argument that will cause the module to wait for a specific
+ condition before returning or timing out if the condition is not met.
+- This module does not support running commands in configuration mode. Please use
+ M(ios_config) to configure IOS devices.
+extends_documentation_fragment:
+- cisco.ios.ios
+notes:
+- Tested against IOS 15.6
+options:
+ commands:
+ description:
+ - List of commands to send to the remote ios device over the configured provider.
+ The resulting output from the command is returned. If the I(wait_for) argument
+ is provided, the module is not returned until the condition is satisfied or
+ the number of retries has expired. If a command sent to the device requires
+ answering a prompt, it is possible to pass a dict containing I(command), I(answer)
+ and I(prompt). Common answers are 'y' or "\r" (carriage return, must be double
+ quotes). See examples.
+ required: true
+ wait_for:
+ description:
+ - List of conditions to evaluate against the output of the command. The task will
+ wait for each condition to be true before moving forward. If the conditional
+ is not true within the configured number of retries, the task fails. See examples.
+ aliases:
+ - waitfor
+ match:
+ description:
+ - The I(match) argument is used in conjunction with the I(wait_for) argument to
+ specify the match policy. Valid values are C(all) or C(any). If the value
+ is set to C(all) then all conditionals in the wait_for must be satisfied. If
+ the value is set to C(any) then only one of the values must be satisfied.
+ default: all
+ choices:
+ - any
+ - all
+ retries:
+ description:
+    - Specifies the number of retries a command should be tried before it is considered
+ failed. The command is run on the target device every retry and evaluated against
+ the I(wait_for) conditions.
+ default: 10
+ interval:
+ description:
+ - Configures the interval in seconds to wait between retries of the command. If
+ the command does not pass the specified conditions, the interval indicates how
+ long to wait before trying the command again.
+ default: 1
+"""
+
+EXAMPLES = r"""
+tasks:
+ - name: run show version on remote devices
+ ios_command:
+ commands: show version
+
+ - name: run show version and check to see if output contains IOS
+ ios_command:
+ commands: show version
+ wait_for: result[0] contains IOS
+
+ - name: run multiple commands on remote nodes
+ ios_command:
+ commands:
+ - show version
+ - show interfaces
+
+ - name: run multiple commands and evaluate the output
+ ios_command:
+ commands:
+ - show version
+ - show interfaces
+ wait_for:
+ - result[0] contains IOS
+ - result[1] contains Loopback0
+
+ - name: run commands that require answering a prompt
+ ios_command:
+ commands:
+ - command: 'clear counters GigabitEthernet0/1'
+ prompt: 'Clear "show interface" counters on this interface \[confirm\]'
+ answer: 'y'
+ - command: 'clear counters GigabitEthernet0/2'
+ prompt: '[confirm]'
+ answer: "\r"
+"""
+
+RETURN = """
+stdout:
+ description: The set of responses from the commands
+ returned: always apart from low level errors (such as action plugin)
+ type: list
+ sample: ['...', '...']
+stdout_lines:
+ description: The value of stdout split into a list
+ returned: always apart from low level errors (such as action plugin)
+ type: list
+ sample: [['...', '...'], ['...'], ['...']]
+failed_conditions:
+ description: The list of conditionals that have failed
+ returned: failed
+ type: list
+ sample: ['...', '...']
+"""
+import time
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.basic import AnsibleModule
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import (
+ Conditional,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ transform_commands,
+ to_lines,
+)
+from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import (
+ run_commands,
+)
+from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import (
+ ios_argument_spec,
+)
+
+
+def parse_commands(module, warnings):
+ commands = transform_commands(module)
+
+ if module.check_mode:
+ for item in list(commands):
+ if not item["command"].startswith("show"):
+ warnings.append(
+ "Only show commands are supported when using check mode, not "
+ "executing %s" % item["command"]
+ )
+ commands.remove(item)
+
+ return commands
+
+
+def main():
+ """main entry point for module execution
+ """
+ argument_spec = dict(
+ commands=dict(type="list", required=True),
+ wait_for=dict(type="list", aliases=["waitfor"]),
+ match=dict(default="all", choices=["all", "any"]),
+ retries=dict(default=10, type="int"),
+ interval=dict(default=1, type="int"),
+ )
+
+ argument_spec.update(ios_argument_spec)
+
+ module = AnsibleModule(
+ argument_spec=argument_spec, supports_check_mode=True
+ )
+
+ warnings = list()
+ result = {"changed": False, "warnings": warnings}
+ commands = parse_commands(module, warnings)
+ wait_for = module.params["wait_for"] or list()
+
+ try:
+ conditionals = [Conditional(c) for c in wait_for]
+ except AttributeError as exc:
+ module.fail_json(msg=to_text(exc))
+
+ retries = module.params["retries"]
+ interval = module.params["interval"]
+ match = module.params["match"]
+
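+    # Poll the device: re-run the commands every `interval` seconds until the
+    # wait_for conditionals pass (all or any, per `match`) or retries run out.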
+ while retries > 0:
+ responses = run_commands(module, commands)
+
+ for item in list(conditionals):
+ if item(responses):
+ if match == "any":
+ conditionals = list()
+ break
+ conditionals.remove(item)
+
+ if not conditionals:
+ break
+
+ time.sleep(interval)
+ retries -= 1
+
+ if conditionals:
+ failed_conditions = [item.raw for item in conditionals]
+ msg = "One or more conditional statements have not been satisfied"
+ module.fail_json(msg=msg, failed_conditions=failed_conditions)
+
+ result.update(
+ {"stdout": responses, "stdout_lines": list(to_lines(responses))}
+ )
+
+ module.exit_json(**result)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py
new file mode 100644
index 00000000..beec5b8d
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py
@@ -0,0 +1,596 @@
+#!/usr/bin/python
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: ios_config
+author: Peter Sprygada (@privateip)
+short_description: Manage Cisco IOS configuration sections
+description:
+- Cisco IOS configurations use a simple block indent file syntax for segmenting configuration
+ into sections. This module provides an implementation for working with IOS configuration
+ sections in a deterministic way.
+extends_documentation_fragment:
+- cisco.ios.ios
+notes:
+- Tested against IOS 15.6
+- Abbreviated commands are NOT idempotent, see L(Network FAQ,../network/user_guide/faq.html#why-do-the-config-modules-always-return-changed-true-with-abbreviated-commands).
+options:
+ lines:
+ description:
+ - The ordered set of commands that should be configured in the section. The commands
+ must be the exact same commands as found in the device running-config. Be sure
+ to note the configuration command syntax as some commands are automatically
+ modified by the device config parser.
+ aliases:
+ - commands
+ parents:
+ description:
+ - The ordered set of parents that uniquely identify the section or hierarchy the
+ commands should be checked against. If the parents argument is omitted, the
+ commands are checked against the set of top level or global commands.
+ src:
+ description:
+ - Specifies the source path to the file that contains the configuration or configuration
+ template to load. The path to the source file can either be the full path on
+ the Ansible control host or a relative path from the playbook or role root directory. This
+ argument is mutually exclusive with I(lines), I(parents).
+ before:
+ description:
+ - The ordered set of commands to push on to the command stack if a change needs
+ to be made. This allows the playbook designer the opportunity to perform configuration
+ commands prior to pushing any changes without affecting how the set of commands
+ are matched against the system.
+ after:
+ description:
+ - The ordered set of commands to append to the end of the command stack if a change
+ needs to be made. Just like with I(before) this allows the playbook designer
+ to append a set of commands to be executed after the command set.
+ match:
+ description:
+ - Instructs the module on the way to perform the matching of the set of commands
+ against the current device config. If match is set to I(line), commands are
+ matched line by line. If match is set to I(strict), command lines are matched
+ with respect to position. If match is set to I(exact), command lines must be
+ an equal match. Finally, if match is set to I(none), the module will not attempt
+ to compare the source configuration with the running configuration on the remote
+ device.
+ choices:
+ - line
+ - strict
+ - exact
+ - none
+ default: line
+ replace:
+ description:
+ - Instructs the module on the way to perform the configuration on the device.
+ If the replace argument is set to I(line) then the modified lines are pushed
+ to the device in configuration mode. If the replace argument is set to I(block)
+ then the entire command block is pushed to the device in configuration mode
+ if any line is not correct.
+ default: line
+ choices:
+ - line
+ - block
+ multiline_delimiter:
+ description:
+ - This argument is used when pushing a multiline configuration element to the
+ IOS device. It specifies the character to use as the delimiting character. This
+ only applies to the configuration action.
+ default: '@'
+ backup:
+ description:
+ - This argument will cause the module to create a full backup of the current C(running-config)
+ from the remote device before any changes are made. If the C(backup_options)
+ value is not given, the backup file is written to the C(backup) folder in the
+ playbook root directory or role root directory, if playbook is part of an ansible
+ role. If the directory does not exist, it is created.
+ type: bool
+ default: 'no'
+ running_config:
+ description:
+ - The module, by default, will connect to the remote device and retrieve the current
+ running-config to use as a base for comparing against the contents of source.
+ There are times when it is not desirable to have the task get the current running-config
+ for every task in a playbook. The I(running_config) argument allows the implementer
+ to pass in the configuration to use as the base config for comparison.
+ aliases:
+ - config
+ defaults:
+ description:
+ - This argument specifies whether or not to collect all defaults when getting
+ the remote device running config. When enabled, the module will get the current
+ config by issuing the command C(show running-config all).
+ type: bool
+ default: 'no'
+ save_when:
+ description:
+ - When changes are made to the device running-configuration, the changes are not
+ copied to non-volatile storage by default. Using this argument will change
+      that behavior. If the argument is set to I(always), then the running-config will
+ always be copied to the startup-config and the I(modified) flag will always
+ be set to True. If the argument is set to I(modified), then the running-config
+ will only be copied to the startup-config if it has changed since the last save
+ to startup-config. If the argument is set to I(never), the running-config will
+ never be copied to the startup-config. If the argument is set to I(changed),
+ then the running-config will only be copied to the startup-config if the task
+ has made a change. I(changed) was added in Ansible 2.5.
+ default: never
+ choices:
+ - always
+ - never
+ - modified
+ - changed
+ diff_against:
+ description:
+ - When using the C(ansible-playbook --diff) command line argument the module can
+ generate diffs against different sources.
+    - When this option is configured as I(startup), the module will return the diff
+ of the running-config against the startup-config.
+ - When this option is configured as I(intended), the module will return the diff
+ of the running-config against the configuration provided in the C(intended_config)
+ argument.
+ - When this option is configured as I(running), the module will return the before
+ and after diff of the running-config with respect to any changes made to the
+ device configuration.
+ choices:
+ - running
+ - startup
+ - intended
+ diff_ignore_lines:
+ description:
+ - Use this argument to specify one or more lines that should be ignored during
+ the diff. This is used for lines in the configuration that are automatically
+ updated by the system. This argument takes a list of regular expressions or
+ exact line matches.
+ intended_config:
+ description:
+ - The C(intended_config) provides the master configuration that the node should
+ conform to and is used to check the final running-config against. This argument
+ will not modify any settings on the remote device and is strictly used to check
+      the compliance of the current device's configuration against it. When specifying
+ this argument, the task should also modify the C(diff_against) value and set
+ it to I(intended).
+ backup_options:
+ description:
+ - This is a dict object containing configurable options related to backup file
+ path. The value of this option is read only when C(backup) is set to I(yes),
+ if C(backup) is set to I(no) this option will be silently ignored.
+ suboptions:
+ filename:
+ description:
+ - The filename to be used to store the backup configuration. If the filename
+ is not given it will be generated based on the hostname, current time and
+ date in format defined by <hostname>_config.<current-date>@<current-time>
+ dir_path:
+ description:
+ - This option provides the path ending with directory name in which the backup
+ configuration file will be stored. If the directory does not exist it will
+          be created first, and the filename is either the value of C(filename) or the
+          default filename as described in the C(filename) option's description. If the
+          path value is not given, a I(backup) directory will be created
+          in the current working directory and the backup configuration will be copied
+          in C(filename) within the I(backup) directory.
+ type: path
+ type: dict
+"""
+
+EXAMPLES = """
+- name: configure top level configuration
+ ios_config:
+ lines: hostname {{ inventory_hostname }}
+
+- name: configure interface settings
+ ios_config:
+ lines:
+ - description test interface
+ - ip address 172.31.1.1 255.255.255.0
+ parents: interface Ethernet1
+
+- name: configure ip helpers on multiple interfaces
+ ios_config:
+ lines:
+ - ip helper-address 172.26.1.10
+ - ip helper-address 172.26.3.8
+ parents: "{{ item }}"
+ with_items:
+ - interface Ethernet1
+ - interface Ethernet2
+ - interface GigabitEthernet1
+
+- name: configure policer in Scavenger class
+ ios_config:
+ lines:
+ - conform-action transmit
+ - exceed-action drop
+ parents:
+ - policy-map Foo
+ - class Scavenger
+ - police cir 64000
+
+- name: load new acl into device
+ ios_config:
+ lines:
+ - 10 permit ip host 192.0.2.1 any log
+ - 20 permit ip host 192.0.2.2 any log
+ - 30 permit ip host 192.0.2.3 any log
+ - 40 permit ip host 192.0.2.4 any log
+ - 50 permit ip host 192.0.2.5 any log
+ parents: ip access-list extended test
+ before: no ip access-list extended test
+ match: exact
+
+- name: check the running-config against master config
+ ios_config:
+ diff_against: intended
+ intended_config: "{{ lookup('file', 'master.cfg') }}"
+
+- name: check the startup-config against the running-config
+ ios_config:
+ diff_against: startup
+ diff_ignore_lines:
+ - ntp clock .*
+
+- name: save running to startup when modified
+ ios_config:
+ save_when: modified
+
+- name: for idempotency, use full-form commands
+ ios_config:
+ lines:
+ # - shut
+ - shutdown
+ # parents: int gig1/0/11
+ parents: interface GigabitEthernet1/0/11
+
+# Set boot image based on comparison to a group_var (version) and the version
+# that is returned from the `ios_facts` module
+- name: SETTING BOOT IMAGE
+ ios_config:
+ lines:
+ - no boot system
+ - boot system flash bootflash:{{new_image}}
+ host: "{{ inventory_hostname }}"
+ when: ansible_net_version != version
+
+- name: render a Jinja2 template onto an IOS device
+ ios_config:
+ backup: yes
+ src: ios_template.j2
+
+- name: configurable backup path
+ ios_config:
+ src: ios_template.j2
+ backup: yes
+ backup_options:
+ filename: backup.cfg
+ dir_path: /home/user
+"""
+
+RETURN = """
+updates:
+ description: The set of commands that will be pushed to the remote device
+ returned: always
+ type: list
+ sample: ['hostname foo', 'router ospf 1', 'router-id 192.0.2.1']
+commands:
+ description: The set of commands that will be pushed to the remote device
+ returned: always
+ type: list
+ sample: ['hostname foo', 'router ospf 1', 'router-id 192.0.2.1']
+backup_path:
+ description: The full path to the backup file
+ returned: when backup is yes
+ type: str
+ sample: /playbooks/ansible/backup/ios_config.2016-07-16@22:28:34
+filename:
+ description: The name of the backup file
+ returned: when backup is yes and filename is not specified in backup options
+ type: str
+ sample: ios_config.2016-07-16@22:28:34
+shortname:
+ description: The full path to the backup file excluding the timestamp
+ returned: when backup is yes and filename is not specified in backup options
+ type: str
+ sample: /playbooks/ansible/backup/ios_config
+date:
+ description: The date extracted from the backup file name
+ returned: when backup is yes
+ type: str
+ sample: "2016-07-16"
+time:
+ description: The time extracted from the backup file name
+ returned: when backup is yes
+ type: str
+ sample: "22:28:34"
+"""
+import json
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.connection import ConnectionError
+from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import (
+ run_commands,
+ get_config,
+)
+from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import (
+ get_defaults_flag,
+ get_connection,
+)
+from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import (
+ ios_argument_spec,
+)
+from ansible.module_utils.basic import AnsibleModule
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import (
+ NetworkConfig,
+ dumps,
+)
+
+
+def check_args(module, warnings):
+ if module.params["multiline_delimiter"]:
+ if len(module.params["multiline_delimiter"]) != 1:
+ module.fail_json(
+ msg="multiline_delimiter value can only be a "
+ "single character"
+ )
+
+
+def edit_config_or_macro(connection, commands):
+ # only catch the macro configuration command,
+ # not negated 'no' variation.
+ if commands[0].startswith("macro name"):
+ connection.edit_macro(candidate=commands)
+ else:
+ connection.edit_config(candidate=commands)
+
+
+def get_candidate_config(module):
+ candidate = ""
+ if module.params["src"]:
+ candidate = module.params["src"]
+
+ elif module.params["lines"]:
+ candidate_obj = NetworkConfig(indent=1)
+ parents = module.params["parents"] or list()
+ candidate_obj.add(module.params["lines"], parents=parents)
+ candidate = dumps(candidate_obj, "raw")
+
+ return candidate
+
+
+def get_running_config(module, current_config=None, flags=None):
+ running = module.params["running_config"]
+ if not running:
+ if not module.params["defaults"] and current_config:
+ running = current_config
+ else:
+ running = get_config(module, flags=flags)
+
+ return running
+
+
+def save_config(module, result):
+ result["changed"] = True
+ if not module.check_mode:
+ run_commands(module, "copy running-config startup-config\r")
+ else:
+ module.warn(
+ "Skipping command `copy running-config startup-config` "
+ "due to check_mode. Configuration not copied to "
+ "non-volatile storage"
+ )
+
+
+def main():
+ """ main entry point for module execution
+ """
+ backup_spec = dict(filename=dict(), dir_path=dict(type="path"))
+ argument_spec = dict(
+ src=dict(type="path"),
+ lines=dict(aliases=["commands"], type="list"),
+ parents=dict(type="list"),
+ before=dict(type="list"),
+ after=dict(type="list"),
+ match=dict(
+ default="line", choices=["line", "strict", "exact", "none"]
+ ),
+ replace=dict(default="line", choices=["line", "block"]),
+ multiline_delimiter=dict(default="@"),
+ running_config=dict(aliases=["config"]),
+ intended_config=dict(),
+ defaults=dict(type="bool", default=False),
+ backup=dict(type="bool", default=False),
+ backup_options=dict(type="dict", options=backup_spec),
+ save_when=dict(
+ choices=["always", "never", "modified", "changed"], default="never"
+ ),
+ diff_against=dict(choices=["startup", "intended", "running"]),
+ diff_ignore_lines=dict(type="list"),
+ )
+
+ argument_spec.update(ios_argument_spec)
+
+ mutually_exclusive = [("lines", "src"), ("parents", "src")]
+
+ required_if = [
+ ("match", "strict", ["lines"]),
+ ("match", "exact", ["lines"]),
+ ("replace", "block", ["lines"]),
+ ("diff_against", "intended", ["intended_config"]),
+ ]
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ mutually_exclusive=mutually_exclusive,
+ required_if=required_if,
+ supports_check_mode=True,
+ )
+
+ result = {"changed": False}
+
+ warnings = list()
+ check_args(module, warnings)
+ result["warnings"] = warnings
+
+ diff_ignore_lines = module.params["diff_ignore_lines"]
+ config = None
+ contents = None
+ flags = get_defaults_flag(module) if module.params["defaults"] else []
+ connection = get_connection(module)
+
+ if module.params["backup"] or (
+ module._diff and module.params["diff_against"] == "running"
+ ):
+ contents = get_config(module, flags=flags)
+ config = NetworkConfig(indent=1, contents=contents)
+ if module.params["backup"]:
+ result["__backup__"] = contents
+
+ if any((module.params["lines"], module.params["src"])):
+ match = module.params["match"]
+ replace = module.params["replace"]
+ path = module.params["parents"]
+
+ candidate = get_candidate_config(module)
+ running = get_running_config(module, contents, flags=flags)
+ try:
+ response = connection.get_diff(
+ candidate=candidate,
+ running=running,
+ diff_match=match,
+ diff_ignore_lines=diff_ignore_lines,
+ path=path,
+ diff_replace=replace,
+ )
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+
+ config_diff = response["config_diff"]
+ banner_diff = response["banner_diff"]
+
+ if config_diff or banner_diff:
+ commands = config_diff.split("\n")
+
+ if module.params["before"]:
+ commands[:0] = module.params["before"]
+
+ if module.params["after"]:
+ commands.extend(module.params["after"])
+
+ result["commands"] = commands
+ result["updates"] = commands
+ result["banners"] = banner_diff
+
+ # send the configuration commands to the device and merge
+ # them with the current running config
+ if not module.check_mode:
+ if commands:
+ edit_config_or_macro(connection, commands)
+ if banner_diff:
+ connection.edit_banner(
+ candidate=json.dumps(banner_diff),
+ multiline_delimiter=module.params[
+ "multiline_delimiter"
+ ],
+ )
+
+ result["changed"] = True
+
+ running_config = module.params["running_config"]
+ startup_config = None
+
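+    # Persist the running-config to the startup-config per save_when: always,
+    # only when it differs from startup ("modified"), or only when this task
+    # changed something ("changed").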
+ if module.params["save_when"] == "always":
+ save_config(module, result)
+ elif module.params["save_when"] == "modified":
+ output = run_commands(
+ module, ["show running-config", "show startup-config"]
+ )
+
+ running_config = NetworkConfig(
+ indent=1, contents=output[0], ignore_lines=diff_ignore_lines
+ )
+ startup_config = NetworkConfig(
+ indent=1, contents=output[1], ignore_lines=diff_ignore_lines
+ )
+
+ if running_config.sha1 != startup_config.sha1:
+ save_config(module, result)
+ elif module.params["save_when"] == "changed" and result["changed"]:
+ save_config(module, result)
+
+ if module._diff:
+ if not running_config:
+ output = run_commands(module, "show running-config")
+ contents = output[0]
+ else:
+ contents = running_config
+
+ # recreate the object in order to process diff_ignore_lines
+ running_config = NetworkConfig(
+ indent=1, contents=contents, ignore_lines=diff_ignore_lines
+ )
+
+ if module.params["diff_against"] == "running":
+ if module.check_mode:
+ module.warn(
+ "unable to perform diff against running-config due to check mode"
+ )
+ contents = None
+ else:
+ contents = config.config_text
+
+ elif module.params["diff_against"] == "startup":
+ if not startup_config:
+ output = run_commands(module, "show startup-config")
+ contents = output[0]
+ else:
+ contents = startup_config.config_text
+
+ elif module.params["diff_against"] == "intended":
+ contents = module.params["intended_config"]
+
+ if contents is not None:
+ base_config = NetworkConfig(
+ indent=1, contents=contents, ignore_lines=diff_ignore_lines
+ )
+
+ if running_config.sha1 != base_config.sha1:
+ if module.params["diff_against"] == "intended":
+ before = running_config
+ after = base_config
+ elif module.params["diff_against"] in ("startup", "running"):
+ before = base_config
+ after = running_config
+
+ result.update(
+ {
+ "changed": True,
+ "diff": {"before": str(before), "after": str(after)},
+ }
+ )
+
+ module.exit_json(**result)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py
new file mode 100644
index 00000000..29f31b0e
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py
@@ -0,0 +1,115 @@
+#
+# (c) 2016 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import json
+import re
+
+from ansible.errors import AnsibleConnectionFailure
+from ansible.module_utils._text import to_text, to_bytes
+from ansible.plugins.terminal import TerminalBase
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class TerminalModule(TerminalBase):
+
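+    # Matches typical IOS CLI prompts, e.g. "Router>" or "Router(config-if)#".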
+ terminal_stdout_re = [
+ re.compile(br"[\r\n]?[\w\+\-\.:\/\[\]]+(?:\([^\)]+\)){0,3}(?:[>#]) ?$")
+ ]
+
+ terminal_stderr_re = [
+ re.compile(br"% ?Error"),
+ # re.compile(br"^% \w+", re.M),
+ re.compile(br"% ?Bad secret"),
+ re.compile(br"[\r\n%] Bad passwords"),
+ re.compile(br"invalid input", re.I),
+ re.compile(br"(?:incomplete|ambiguous) command", re.I),
+ re.compile(br"connection timed out", re.I),
+ re.compile(br"[^\r\n]+ not found"),
+ re.compile(br"'[^']' +returned error code: ?\d+"),
+ re.compile(br"Bad mask", re.I),
+ re.compile(br"% ?(\S+) ?overlaps with ?(\S+)", re.I),
+ re.compile(br"[%\S] ?Error: ?[\s]+", re.I),
+ re.compile(br"[%\S] ?Informational: ?[\s]+", re.I),
+ re.compile(br"Command authorization failed"),
+ ]
+
+ def on_open_shell(self):
+ try:
+ self._exec_cli_command(b"terminal length 0")
+ except AnsibleConnectionFailure:
+ raise AnsibleConnectionFailure("unable to set terminal parameters")
+
+ try:
+ self._exec_cli_command(b"terminal width 512")
+ try:
+ self._exec_cli_command(b"terminal width 0")
+ except AnsibleConnectionFailure:
+ pass
+ except AnsibleConnectionFailure:
+ display.display(
+ "WARNING: Unable to set terminal width, command responses may be truncated"
+ )
+
+ def on_become(self, passwd=None):
+ if self._get_prompt().endswith(b"#"):
+ return
+
+ cmd = {u"command": u"enable"}
+ if passwd:
+ # Note: python-3.5 cannot combine u"" and r"" together. Thus make
+ # an r string and use to_text to ensure it's text on both py2 and py3.
+ cmd[u"prompt"] = to_text(
+ r"[\r\n]?(?:.*)?[Pp]assword: ?$", errors="surrogate_or_strict"
+ )
+ cmd[u"answer"] = passwd
+ cmd[u"prompt_retry_check"] = True
+ try:
+ self._exec_cli_command(
+ to_bytes(json.dumps(cmd), errors="surrogate_or_strict")
+ )
+ prompt = self._get_prompt()
+ if prompt is None or not prompt.endswith(b"#"):
+ raise AnsibleConnectionFailure(
+ "failed to elevate privilege to enable mode still at prompt [%s]"
+ % prompt
+ )
+ except AnsibleConnectionFailure as e:
+ prompt = self._get_prompt()
+ raise AnsibleConnectionFailure(
+ "unable to elevate privilege to enable mode, at prompt [%s] with error: %s"
+ % (prompt, e.message)
+ )
+
+ def on_unbecome(self):
+ prompt = self._get_prompt()
+ if prompt is None:
+ # if prompt is None most likely the terminal is hung up at a prompt
+ return
+
+ if b"(config" in prompt:
+ self._exec_cli_command(b"end")
+ self._exec_cli_command(b"disable")
+
+ elif prompt.endswith(b"#"):
+ self._exec_cli_command(b"disable")
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py
new file mode 100644
index 00000000..cab2f3fd
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py
@@ -0,0 +1,129 @@
+#
+# (c) 2016 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import sys
+import copy
+
+from ansible_collections.ansible.netcommon.plugins.action.network import (
+ ActionModule as ActionNetworkModule,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ load_provider,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import (
+ vyos_provider_spec,
+)
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class ActionModule(ActionNetworkModule):
+ def run(self, tmp=None, task_vars=None):
+ del tmp # tmp no longer has any effect
+
+ module_name = self._task.action.split(".")[-1]
+        self._config_module = module_name == "vyos_config"
+ persistent_connection = self._play_context.connection.split(".")[-1]
+ warnings = []
+
+ if persistent_connection == "network_cli":
+ provider = self._task.args.get("provider", {})
+ if any(provider.values()):
+ display.warning(
+ "provider is unnecessary when using network_cli and will be ignored"
+ )
+ del self._task.args["provider"]
+ elif self._play_context.connection == "local":
+ provider = load_provider(vyos_provider_spec, self._task.args)
+ pc = copy.deepcopy(self._play_context)
+ pc.connection = "ansible.netcommon.network_cli"
+ pc.network_os = "vyos.vyos.vyos"
+ pc.remote_addr = provider["host"] or self._play_context.remote_addr
+ pc.port = int(provider["port"] or self._play_context.port or 22)
+ pc.remote_user = (
+ provider["username"] or self._play_context.connection_user
+ )
+ pc.password = provider["password"] or self._play_context.password
+ pc.private_key_file = (
+ provider["ssh_keyfile"] or self._play_context.private_key_file
+ )
+
+ connection = self._shared_loader_obj.connection_loader.get(
+ "ansible.netcommon.persistent",
+ pc,
+ sys.stdin,
+ task_uuid=self._task._uuid,
+ )
+
+ # TODO: Remove below code after ansible minimal is cut out
+ if connection is None:
+ pc.connection = "network_cli"
+ pc.network_os = "vyos"
+ connection = self._shared_loader_obj.connection_loader.get(
+ "persistent", pc, sys.stdin, task_uuid=self._task._uuid
+ )
+
+ display.vvv(
+ "using connection plugin %s (was local)" % pc.connection,
+ pc.remote_addr,
+ )
+
+ command_timeout = (
+ int(provider["timeout"])
+ if provider["timeout"]
+ else connection.get_option("persistent_command_timeout")
+ )
+ connection.set_options(
+ direct={"persistent_command_timeout": command_timeout}
+ )
+
+ socket_path = connection.run()
+ display.vvvv("socket_path: %s" % socket_path, pc.remote_addr)
+ if not socket_path:
+ return {
+ "failed": True,
+ "msg": "unable to open shell. Please see: "
+ + "https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell",
+ }
+
+ task_vars["ansible_socket"] = socket_path
+ warnings.append(
+ [
+ "connection local support for this module is deprecated and will be removed in version 2.14, use connection %s"
+ % pc.connection
+ ]
+ )
+ else:
+ return {
+ "failed": True,
+ "msg": "Connection type %s is not valid for this module"
+ % self._play_context.connection,
+ }
+
+ result = super(ActionModule, self).run(task_vars=task_vars)
+ if warnings:
+ if "warnings" in result:
+ result["warnings"].extend(warnings)
+ else:
+ result["warnings"] = warnings
+ return result
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py
new file mode 100644
index 00000000..30336031
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py
@@ -0,0 +1,342 @@
+#
+# (c) 2017 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+DOCUMENTATION = """
+---
+author: Ansible Networking Team
+cliconf: vyos
+short_description: Use vyos cliconf to run commands on VyOS platform
+description:
+ - This vyos plugin provides low-level abstraction APIs for
+ sending and receiving CLI commands to and from VyOS network devices.
+version_added: "2.4"
+"""
+
+import re
+import json
+
+from ansible.errors import AnsibleConnectionFailure
+from ansible.module_utils._text import to_text
+from ansible.module_utils.common._collections_compat import Mapping
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import (
+ NetworkConfig,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ to_list,
+)
+from ansible.plugins.cliconf import CliconfBase
+
+
+class Cliconf(CliconfBase):
+ def get_device_info(self):
+ device_info = {}
+
+ device_info["network_os"] = "vyos"
+ reply = self.get("show version")
+ data = to_text(reply, errors="surrogate_or_strict").strip()
+
+ match = re.search(r"Version:\s*(.*)", data)
+ if match:
+ device_info["network_os_version"] = match.group(1)
+
+ match = re.search(r"HW model:\s*(\S+)", data)
+ if match:
+ device_info["network_os_model"] = match.group(1)
+
+ reply = self.get("show host name")
+ device_info["network_os_hostname"] = to_text(
+ reply, errors="surrogate_or_strict"
+ ).strip()
+
+ return device_info
+
+ def get_config(self, flags=None, format=None):
+ if format:
+ option_values = self.get_option_values()
+ if format not in option_values["format"]:
+ raise ValueError(
+ "'format' value %s is invalid. Valid values of format are %s"
+ % (format, ", ".join(option_values["format"]))
+ )
+
+ if not flags:
+ flags = []
+
+ if format == "text":
+ command = "show configuration"
+ else:
+ command = "show configuration commands"
+
+ command += " ".join(to_list(flags))
+ command = command.strip()
+
+ out = self.send_command(command)
+ return out
+
+ def edit_config(
+ self, candidate=None, commit=True, replace=None, comment=None
+ ):
+ resp = {}
+ operations = self.get_device_operations()
+ self.check_edit_config_capability(
+ operations, candidate, commit, replace, comment
+ )
+
+ results = []
+ requests = []
+ self.send_command("configure")
+ for cmd in to_list(candidate):
+ if not isinstance(cmd, Mapping):
+ cmd = {"command": cmd}
+
+ results.append(self.send_command(**cmd))
+ requests.append(cmd["command"])
+ out = self.get("compare")
+ out = to_text(out, errors="surrogate_or_strict")
+ diff_config = out if not out.startswith("No changes") else None
+
+ if diff_config:
+ if commit:
+ try:
+ self.commit(comment)
+ except AnsibleConnectionFailure as e:
+ msg = "commit failed: %s" % e.message
+ self.discard_changes()
+ raise AnsibleConnectionFailure(msg)
+ else:
+ self.send_command("exit")
+ else:
+ self.discard_changes()
+ else:
+ self.send_command("exit")
+ if (
+ to_text(
+ self._connection.get_prompt(), errors="surrogate_or_strict"
+ )
+ .strip()
+ .endswith("#")
+ ):
+ self.discard_changes()
+
+ if diff_config:
+ resp["diff"] = diff_config
+ resp["response"] = results
+ resp["request"] = requests
+ return resp
+
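+ # Editorial note: edit_config() returns a dict shaped as
+ # {"request": [<commands sent>], "response": [<device replies>]},
+ # plus a "diff" key (the output of 'compare') when changes were staged.
+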
+ def get(
+ self,
+ command=None,
+ prompt=None,
+ answer=None,
+ sendonly=False,
+ output=None,
+ newline=True,
+ check_all=False,
+ ):
+ if not command:
+ raise ValueError("must provide value of command to execute")
+ if output:
+ raise ValueError(
+ "'output' value %s is not supported for get" % output
+ )
+
+ return self.send_command(
+ command=command,
+ prompt=prompt,
+ answer=answer,
+ sendonly=sendonly,
+ newline=newline,
+ check_all=check_all,
+ )
+
+ def commit(self, comment=None):
+ if comment:
+ command = 'commit comment "{0}"'.format(comment)
+ else:
+ command = "commit"
+ self.send_command(command)
+
+ def discard_changes(self):
+ self.send_command("exit discard")
+
+ def get_diff(
+ self,
+ candidate=None,
+ running=None,
+ diff_match="line",
+ diff_ignore_lines=None,
+ path=None,
+ diff_replace=None,
+ ):
+ diff = {}
+ device_operations = self.get_device_operations()
+ option_values = self.get_option_values()
+
+ if candidate is None and device_operations["supports_generate_diff"]:
+ raise ValueError(
+ "candidate configuration is required to generate diff"
+ )
+
+ if diff_match not in option_values["diff_match"]:
+ raise ValueError(
+ "'match' value %s in invalid, valid values are %s"
+ % (diff_match, ", ".join(option_values["diff_match"]))
+ )
+
+ if diff_replace:
+ raise ValueError("'replace' in diff is not supported")
+
+ if diff_ignore_lines:
+ raise ValueError("'diff_ignore_lines' in diff is not supported")
+
+ if path:
+ raise ValueError("'path' in diff is not supported")
+
+ set_format = candidate.startswith("set") or candidate.startswith(
+ "delete"
+ )
+ candidate_obj = NetworkConfig(indent=4, contents=candidate)
+ if not set_format:
+ config = [c.line for c in candidate_obj.items]
+ commands = list()
+ # this filters out less specific lines
+ for item in config:
+ for index, entry in enumerate(commands):
+ if item.startswith(entry):
+ del commands[index]
+ break
+ commands.append(item)
+
+ candidate_commands = [
+ "set %s" % cmd.replace(" {", "") for cmd in commands
+ ]
+
+ else:
+ candidate_commands = str(candidate).strip().split("\n")
+
+ if diff_match == "none":
+ diff["config_diff"] = list(candidate_commands)
+ return diff
+
+ running_commands = [
+ str(c).replace("'", "") for c in running.splitlines()
+ ]
+
+ updates = list()
+ visited = set()
+
+ for line in candidate_commands:
+ item = str(line).replace("'", "")
+
+ if not item.startswith("set") and not item.startswith("delete"):
+ raise ValueError(
+ "line must start with either `set` or `delete`"
+ )
+
+ elif item.startswith("set") and item not in running_commands:
+ updates.append(line)
+
+ elif item.startswith("delete"):
+ if not running_commands:
+ updates.append(line)
+ else:
+ item = re.sub(r"delete", "set", item)
+ for entry in running_commands:
+ if entry.startswith(item) and line not in visited:
+ updates.append(line)
+ visited.add(line)
+
+ diff["config_diff"] = list(updates)
+ return diff
+
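+ # Illustrative example (editorial, not part of the plugin):
+ # get_diff(candidate="set system host-name vyos01",
+ #          running="set system host-name 'vyos'", diff_match="line")
+ # returns {"config_diff": ["set system host-name vyos01"]}
+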
+ def run_commands(self, commands=None, check_rc=True):
+ if commands is None:
+ raise ValueError("'commands' value is required")
+
+ responses = list()
+ for cmd in to_list(commands):
+ if not isinstance(cmd, Mapping):
+ cmd = {"command": cmd}
+
+ output = cmd.pop("output", None)
+ if output:
+ raise ValueError(
+ "'output' value %s is not supported for run_commands"
+ % output
+ )
+
+ try:
+ out = self.send_command(**cmd)
+ except AnsibleConnectionFailure as e:
+ if check_rc:
+ raise
+ out = getattr(e, "err", e)
+
+ responses.append(out)
+
+ return responses
+
+ def get_device_operations(self):
+ return {
+ "supports_diff_replace": False,
+ "supports_commit": True,
+ "supports_rollback": False,
+ "supports_defaults": False,
+ "supports_onbox_diff": True,
+ "supports_commit_comment": True,
+ "supports_multiline_delimiter": False,
+ "supports_diff_match": True,
+ "supports_diff_ignore_lines": False,
+ "supports_generate_diff": False,
+ "supports_replace": False,
+ }
+
+ def get_option_values(self):
+ return {
+ "format": ["text", "set"],
+ "diff_match": ["line", "none"],
+ "diff_replace": [],
+ "output": [],
+ }
+
+ def get_capabilities(self):
+ result = super(Cliconf, self).get_capabilities()
+ result["rpc"] += [
+ "commit",
+ "discard_changes",
+ "get_diff",
+ "run_commands",
+ ]
+ result["device_operations"] = self.get_device_operations()
+ result.update(self.get_option_values())
+ return json.dumps(result)
+
+ def set_cli_prompt_context(self):
+ """
+ Make sure we are in the operational cli mode
+ :return: None
+ """
+ if self._connection.connected:
+ self._update_cli_prompt_context(
+ config_context="#", exit_command="exit discard"
+ )
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py
new file mode 100644
index 00000000..094963f1
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Peter Sprygada <psprygada@ansible.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+
+class ModuleDocFragment(object):
+
+ # Standard files documentation fragment
+ DOCUMENTATION = r"""options:
+ provider:
+ description:
+ - B(Deprecated)
+ - 'Starting with Ansible 2.5 we recommend using C(connection: network_cli).'
+ - For more information please see the L(Network Guide, ../network/getting_started/network_differences.html#multiple-communication-protocols).
+ - HORIZONTALLINE
+ - A dict object containing connection details.
+ type: dict
+ suboptions:
+ host:
+ description:
+ - Specifies the DNS host name or address for connecting to the remote device
+ over the specified transport. The value of host is used as the destination
+ address for the transport.
+ type: str
+ required: true
+ port:
+ description:
+ - Specifies the port to use when building the connection to the remote device.
+ type: int
+ default: 22
+ username:
+ description:
+ - Configures the username to use to authenticate the connection to the remote
+ device. This value is used to authenticate the SSH session. If the value
+ is not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME)
+ will be used instead.
+ type: str
+ password:
+ description:
+ - Specifies the password to use to authenticate the connection to the remote
+ device. This value is used to authenticate the SSH session. If the value
+ is not specified in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD)
+ will be used instead.
+ type: str
+ timeout:
+ description:
+ - Specifies the timeout in seconds for communicating with the network device
+ for either connecting or sending commands. If the timeout is exceeded before
+ the operation is completed, the module will error.
+ type: int
+ default: 10
+ ssh_keyfile:
+ description:
+ - Specifies the SSH key to use to authenticate the connection to the remote
+ device. This value is the path to the key used to authenticate the SSH
+ session. If the value is not specified in the task, the value of environment
+ variable C(ANSIBLE_NET_SSH_KEYFILE) will be used instead.
+ type: path
+notes:
+- For more information on using Ansible to manage network devices see the :ref:`Ansible
+ Network Guide <network_guide>`
+"""
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py
new file mode 100644
index 00000000..46fabaa2
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py
@@ -0,0 +1,22 @@
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The arg spec for the vyos facts module.
+"""
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class FactsArgs(object): # pylint: disable=R0903
+ """ The arg spec for the vyos facts module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "gather_subset": dict(default=["!config"], type="list"),
+ "gather_network_resources": dict(type="list"),
+ }
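+
+ # Editorial example of arguments this spec accepts:
+ #
+ # vyos_facts:
+ #   gather_subset: "!config"
+ #   gather_network_resources:
+ #     - interfaces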
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py
new file mode 100644
index 00000000..a018cc0b
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py
@@ -0,0 +1,263 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+"""
+The arg spec for the vyos_firewall_rules module
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class Firewall_rulesArgs(object): # pylint: disable=R0903
+ """The arg spec for the vyos_firewall_rules module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "config": {
+ "elements": "dict",
+ "options": {
+ "afi": {
+ "choices": ["ipv4", "ipv6"],
+ "required": True,
+ "type": "str",
+ },
+ "rule_sets": {
+ "elements": "dict",
+ "options": {
+ "default_action": {
+ "choices": ["drop", "reject", "accept"],
+ "type": "str",
+ },
+ "description": {"type": "str"},
+ "enable_default_log": {"type": "bool"},
+ "name": {"type": "str"},
+ "rules": {
+ "elements": "dict",
+ "options": {
+ "action": {
+ "choices": [
+ "drop",
+ "reject",
+ "accept",
+ "inspect",
+ ],
+ "type": "str",
+ },
+ "description": {"type": "str"},
+ "destination": {
+ "options": {
+ "address": {"type": "str"},
+ "group": {
+ "options": {
+ "address_group": {
+ "type": "str"
+ },
+ "network_group": {
+ "type": "str"
+ },
+ "port_group": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ "port": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ "disabled": {"type": "bool"},
+ "fragment": {
+ "choices": [
+ "match-frag",
+ "match-non-frag",
+ ],
+ "type": "str",
+ },
+ "icmp": {
+ "options": {
+ "code": {"type": "int"},
+ "type": {"type": "int"},
+ "type_name": {
+ "choices": [
+ "any",
+ "echo-reply",
+ "destination-unreachable",
+ "network-unreachable",
+ "host-unreachable",
+ "protocol-unreachable",
+ "port-unreachable",
+ "fragmentation-needed",
+ "source-route-failed",
+ "network-unknown",
+ "host-unknown",
+ "network-prohibited",
+ "host-prohibited",
+ "TOS-network-unreachable",
+ "TOS-host-unreachable",
+ "communication-prohibited",
+ "host-precedence-violation",
+ "precedence-cutoff",
+ "source-quench",
+ "redirect",
+ "network-redirect",
+ "host-redirect",
+ "TOS-network-redirect",
+ "TOS-host-redirect",
+ "echo-request",
+ "router-advertisement",
+ "router-solicitation",
+ "time-exceeded",
+ "ttl-zero-during-transit",
+ "ttl-zero-during-reassembly",
+ "parameter-problem",
+ "ip-header-bad",
+ "required-option-missing",
+ "timestamp-request",
+ "timestamp-reply",
+ "address-mask-request",
+ "address-mask-reply",
+ "ping",
+ "pong",
+ "ttl-exceeded",
+ ],
+ "type": "str",
+ },
+ },
+ "type": "dict",
+ },
+ "ipsec": {
+ "choices": ["match-ipsec", "match-none"],
+ "type": "str",
+ },
+ "limit": {
+ "options": {
+ "burst": {"type": "int"},
+ "rate": {
+ "options": {
+ "number": {"type": "int"},
+ "unit": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ },
+ "type": "dict",
+ },
+ "number": {"required": True, "type": "int"},
+ "p2p": {
+ "elements": "dict",
+ "options": {
+ "application": {
+ "choices": [
+ "all",
+ "applejuice",
+ "bittorrent",
+ "directconnect",
+ "edonkey",
+ "gnutella",
+ "kazaa",
+ ],
+ "type": "str",
+ }
+ },
+ "type": "list",
+ },
+ "protocol": {"type": "str"},
+ "recent": {
+ "options": {
+ "count": {"type": "int"},
+ "time": {"type": "int"},
+ },
+ "type": "dict",
+ },
+ "source": {
+ "options": {
+ "address": {"type": "str"},
+ "group": {
+ "options": {
+ "address_group": {
+ "type": "str"
+ },
+ "network_group": {
+ "type": "str"
+ },
+ "port_group": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ "mac_address": {"type": "str"},
+ "port": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ "state": {
+ "options": {
+ "established": {"type": "bool"},
+ "invalid": {"type": "bool"},
+ "new": {"type": "bool"},
+ "related": {"type": "bool"},
+ },
+ "type": "dict",
+ },
+ "tcp": {
+ "options": {"flags": {"type": "str"}},
+ "type": "dict",
+ },
+ "time": {
+ "options": {
+ "monthdays": {"type": "str"},
+ "startdate": {"type": "str"},
+ "starttime": {"type": "str"},
+ "stopdate": {"type": "str"},
+ "stoptime": {"type": "str"},
+ "utc": {"type": "bool"},
+ "weekdays": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ },
+ "type": "list",
+ },
+ },
+ "type": "list",
+ },
+ },
+ "type": "list",
+ },
+ "running_config": {"type": "str"},
+ "state": {
+ "choices": [
+ "merged",
+ "replaced",
+ "overridden",
+ "deleted",
+ "gathered",
+ "rendered",
+ "parsed",
+ ],
+ "default": "merged",
+ "type": "str",
+ },
+ } # pylint: disable=C0301
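+
+ # Editorial example of a config payload this spec validates:
+ #
+ # config:
+ #   - afi: ipv4
+ #     rule_sets:
+ #       - name: INBOUND
+ #         default_action: drop
+ #         rules:
+ #           - number: 101
+ #             action: accept
+ #             protocol: tcp
+ #             destination:
+ #               port: "22"
+ # state: merged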
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py
new file mode 100644
index 00000000..3542cb19
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py
@@ -0,0 +1,69 @@
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+"""
+The arg spec for the vyos_interfaces module
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class InterfacesArgs(object): # pylint: disable=R0903
+ """The arg spec for the vyos_interfaces module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "config": {
+ "elements": "dict",
+ "options": {
+ "description": {"type": "str"},
+ "duplex": {"choices": ["full", "half", "auto"]},
+ "enabled": {"default": True, "type": "bool"},
+ "mtu": {"type": "int"},
+ "name": {"required": True, "type": "str"},
+ "speed": {
+ "choices": ["auto", "10", "100", "1000", "2500", "10000"],
+ "type": "str",
+ },
+ "vifs": {
+ "elements": "dict",
+ "options": {
+ "vlan_id": {"type": "int"},
+ "description": {"type": "str"},
+ "enabled": {"default": True, "type": "bool"},
+ "mtu": {"type": "int"},
+ },
+ "type": "list",
+ },
+ },
+ "type": "list",
+ },
+ "state": {
+ "choices": ["merged", "replaced", "overridden", "deleted"],
+ "default": "merged",
+ "type": "str",
+ },
+ } # pylint: disable=C0301
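+
+ # Editorial example of a config payload this spec validates:
+ #
+ # config:
+ #   - name: eth1
+ #     description: Uplink
+ #     mtu: 1500
+ #     vifs:
+ #       - vlan_id: 100
+ #         description: Management
+ # state: merged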
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py
new file mode 100644
index 00000000..91434e4b
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py
@@ -0,0 +1,81 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+"""
+The arg spec for the vyos_l3_interfaces module
+"""
+
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class L3_interfacesArgs(object): # pylint: disable=R0903
+ """The arg spec for the vyos_l3_interfaces module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "config": {
+ "elements": "dict",
+ "options": {
+ "ipv4": {
+ "elements": "dict",
+ "options": {"address": {"type": "str"}},
+ "type": "list",
+ },
+ "ipv6": {
+ "elements": "dict",
+ "options": {"address": {"type": "str"}},
+ "type": "list",
+ },
+ "name": {"required": True, "type": "str"},
+ "vifs": {
+ "elements": "dict",
+ "options": {
+ "ipv4": {
+ "elements": "dict",
+ "options": {"address": {"type": "str"}},
+ "type": "list",
+ },
+ "ipv6": {
+ "elements": "dict",
+ "options": {"address": {"type": "str"}},
+ "type": "list",
+ },
+ "vlan_id": {"type": "int"},
+ },
+ "type": "list",
+ },
+ },
+ "type": "list",
+ },
+ "state": {
+ "choices": ["merged", "replaced", "overridden", "deleted"],
+ "default": "merged",
+ "type": "str",
+ },
+ } # pylint: disable=C0301
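+
+ # Editorial example of a config payload this spec validates:
+ #
+ # config:
+ #   - name: eth1
+ #     ipv4:
+ #       - address: 192.0.2.10/24
+ # state: merged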
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py
new file mode 100644
index 00000000..97c5d5a2
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py
@@ -0,0 +1,80 @@
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+
+"""
+The arg spec for the vyos_lag_interfaces module
+"""
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class Lag_interfacesArgs(object): # pylint: disable=R0903
+ """The arg spec for the vyos_lag_interfaces module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "config": {
+ "elements": "dict",
+ "options": {
+ "arp_monitor": {
+ "options": {
+ "interval": {"type": "int"},
+ "target": {"type": "list"},
+ },
+ "type": "dict",
+ },
+ "hash_policy": {
+ "choices": ["layer2", "layer2+3", "layer3+4"],
+ "type": "str",
+ },
+ "members": {
+ "elements": "dict",
+ "options": {"member": {"type": "str"}},
+ "type": "list",
+ },
+ "mode": {
+ "choices": [
+ "802.3ad",
+ "active-backup",
+ "broadcast",
+ "round-robin",
+ "transmit-load-balance",
+ "adaptive-load-balance",
+ "xor-hash",
+ ],
+ "type": "str",
+ },
+ "name": {"required": True, "type": "str"},
+ "primary": {"type": "str"},
+ },
+ "type": "list",
+ },
+ "state": {
+ "choices": ["merged", "replaced", "overridden", "deleted"],
+ "default": "merged",
+ "type": "str",
+ },
+ } # pylint: disable=C0301
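+
+ # Editorial example of a config payload this spec validates:
+ #
+ # config:
+ #   - name: bond0
+ #     mode: 802.3ad
+ #     hash_policy: layer2
+ #     members:
+ #       - member: eth1
+ # state: merged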
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py
new file mode 100644
index 00000000..84bbc00c
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py
@@ -0,0 +1,56 @@
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+
+"""
+The arg spec for the vyos_lldp_global module
+"""
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class Lldp_globalArgs(object): # pylint: disable=R0903
+ """The arg spec for the vyos_lldp_global module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "config": {
+ "options": {
+ "address": {"type": "str"},
+ "enable": {"type": "bool"},
+ "legacy_protocols": {
+ "choices": ["cdp", "edp", "fdp", "sonmp"],
+ "type": "list",
+ },
+ "snmp": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ "state": {
+ "choices": ["merged", "replaced", "deleted"],
+ "default": "merged",
+ "type": "str",
+ },
+ } # pylint: disable=C0301
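+
+ # Editorial example of a config payload this spec validates:
+ #
+ # config:
+ #   address: 192.0.2.17
+ #   legacy_protocols:
+ #     - cdp
+ # state: merged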
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py
new file mode 100644
index 00000000..2976fc09
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py
@@ -0,0 +1,89 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+"""
+The arg spec for the vyos_lldp_interfaces module
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class Lldp_interfacesArgs(object): # pylint: disable=R0903
+ """The arg spec for the vyos_lldp_interfaces module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "config": {
+ "elements": "dict",
+ "options": {
+ "enable": {"default": True, "type": "bool"},
+ "location": {
+ "options": {
+ "civic_based": {
+ "options": {
+ "ca_info": {
+ "elements": "dict",
+ "options": {
+ "ca_type": {"type": "int"},
+ "ca_value": {"type": "str"},
+ },
+ "type": "list",
+ },
+ "country_code": {
+ "required": True,
+ "type": "str",
+ },
+ },
+ "type": "dict",
+ },
+ "coordinate_based": {
+ "options": {
+ "altitude": {"type": "int"},
+ "datum": {
+ "choices": ["WGS84", "NAD83", "MLLW"],
+ "type": "str",
+ },
+ "latitude": {"required": True, "type": "str"},
+ "longitude": {"required": True, "type": "str"},
+ },
+ "type": "dict",
+ },
+ "elin": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ "name": {"required": True, "type": "str"},
+ },
+ "type": "list",
+ },
+ "state": {
+ "choices": ["merged", "replaced", "overridden", "deleted"],
+ "default": "merged",
+ "type": "str",
+ },
+ } # pylint: disable=C0301
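+
+ # Editorial example of a config payload this spec validates (the
+ # latitude/longitude values are placeholders):
+ #
+ # config:
+ #   - name: eth1
+ #     location:
+ #       coordinate_based:
+ #         latitude: 33.4230N
+ #         longitude: 86.7081W
+ #         datum: WGS84
+ # state: merged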
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py
new file mode 100644
index 00000000..8ecd955a
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py
@@ -0,0 +1,99 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+"""
+The arg spec for the vyos_static_routes module
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+class Static_routesArgs(object): # pylint: disable=R0903
+ """The arg spec for the vyos_static_routes module
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ argument_spec = {
+ "config": {
+ "elements": "dict",
+ "options": {
+ "address_families": {
+ "elements": "dict",
+ "options": {
+ "afi": {
+ "choices": ["ipv4", "ipv6"],
+ "required": True,
+ "type": "str",
+ },
+ "routes": {
+ "elements": "dict",
+ "options": {
+ "blackhole_config": {
+ "options": {
+ "distance": {"type": "int"},
+ "type": {"type": "str"},
+ },
+ "type": "dict",
+ },
+ "dest": {"required": True, "type": "str"},
+ "next_hops": {
+ "elements": "dict",
+ "options": {
+ "admin_distance": {"type": "int"},
+ "enabled": {"type": "bool"},
+ "forward_router_address": {
+ "required": True,
+ "type": "str",
+ },
+ "interface": {"type": "str"},
+ },
+ "type": "list",
+ },
+ },
+ "type": "list",
+ },
+ },
+ "type": "list",
+ }
+ },
+ "type": "list",
+ },
+ "running_config": {"type": "str"},
+ "state": {
+ "choices": [
+ "merged",
+ "replaced",
+ "overridden",
+ "deleted",
+ "gathered",
+ "rendered",
+ "parsed",
+ ],
+ "default": "merged",
+ "type": "str",
+ },
+ } # pylint: disable=C0301
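+
+ # Editorial example of a config payload this spec validates:
+ #
+ # config:
+ #   - address_families:
+ #       - afi: ipv4
+ #         routes:
+ #           - dest: 192.0.2.32/28
+ #             next_hops:
+ #               - forward_router_address: 192.0.2.1
+ # state: merged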
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py
new file mode 100644
index 00000000..377fec9a
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py
@@ -0,0 +1,438 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos_lldp_interfaces class
+It is in this file where the current configuration (as dict)
+is compared to the provided configuration (as dict) and the command set
+necessary to bring the current configuration to its desired end-state is
+created
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.cfg.base import (
+ ConfigBase,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.facts import (
+ Facts,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ to_list,
+ dict_diff,
+)
+from ansible.module_utils.six import iteritems
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.utils.utils import (
+ search_obj_in_list,
+ search_dict_tv_in_list,
+ key_value_in_dict,
+ is_dict_element_present,
+)
+
+
+class Lldp_interfaces(ConfigBase):
+ """
+ The vyos_lldp_interfaces class
+ """
+
+ gather_subset = [
+ "!all",
+ "!min",
+ ]
+
+ gather_network_resources = [
+ "lldp_interfaces",
+ ]
+
+ params = ["enable", "location", "name"]
+
+ def __init__(self, module):
+ super(Lldp_interfaces, self).__init__(module)
+
+ def get_lldp_interfaces_facts(self):
+ """ Get the 'facts' (the current configuration)
+
+ :rtype: A dictionary
+ :returns: The current configuration as a dictionary
+ """
+ facts, _warnings = Facts(self._module).get_facts(
+ self.gather_subset, self.gather_network_resources
+ )
+ lldp_interfaces_facts = facts["ansible_network_resources"].get(
+ "lldp_interfaces"
+ )
+ if not lldp_interfaces_facts:
+ return []
+ return lldp_interfaces_facts
+
+ def execute_module(self):
+ """ Execute the module
+
+ :rtype: A dictionary
+ :returns: The result from module execution
+ """
+ result = {"changed": False}
+ commands = list()
+ warnings = list()
+ existing_lldp_interfaces_facts = self.get_lldp_interfaces_facts()
+ commands.extend(self.set_config(existing_lldp_interfaces_facts))
+ if commands:
+ if self._module.check_mode:
+ resp = self._connection.edit_config(commands, commit=False)
+ else:
+ resp = self._connection.edit_config(commands)
+ result["changed"] = True
+
+ result["commands"] = commands
+
+ if self._module._diff:
+ result["diff"] = resp["diff"] if result["changed"] else None
+
+ changed_lldp_interfaces_facts = self.get_lldp_interfaces_facts()
+ result["before"] = existing_lldp_interfaces_facts
+ if result["changed"]:
+ result["after"] = changed_lldp_interfaces_facts
+
+ result["warnings"] = warnings
+ return result
+
+ def set_config(self, existing_lldp_interfaces_facts):
+ """ Collect the configuration from the args passed to the module,
+ collect the current configuration (as a dict from facts)
+
+ :rtype: A list
+ :returns: the commands necessary to migrate the current configuration
+ to the desired configuration
+ """
+ want = self._module.params["config"]
+ have = existing_lldp_interfaces_facts
+ resp = self.set_state(want, have)
+ return to_list(resp)
+
+ def set_state(self, want, have):
+ """ Select the appropriate function based on the state provided
+
+ :param want: the desired configuration as a dictionary
+ :param have: the current configuration as a dictionary
+ :rtype: A list
+ :returns: the commands necessary to migrate the current configuration
+ to the desired configuration
+ """
+ commands = []
+ state = self._module.params["state"]
+ if state in ("merged", "replaced", "overridden") and not want:
+ self._module.fail_json(
+ msg="value of config parameter must not be empty for state {0}".format(
+ state
+ )
+ )
+ if state == "overridden":
+ commands.extend(self._state_overridden(want=want, have=have))
+ elif state == "deleted":
+ if want:
+ for item in want:
+ name = item["name"]
+ have_item = search_obj_in_list(name, have)
+ commands.extend(
+ self._state_deleted(want=None, have=have_item)
+ )
+ else:
+ for have_item in have:
+ commands.extend(
+ self._state_deleted(want=None, have=have_item)
+ )
+ else:
+ for want_item in want:
+ name = want_item["name"]
+ have_item = search_obj_in_list(name, have)
+ if state == "merged":
+ commands.extend(
+ self._state_merged(want=want_item, have=have_item)
+ )
+ else:
+ commands.extend(
+ self._state_replaced(want=want_item, have=have_item)
+ )
+ return commands
+
+ def _state_replaced(self, want, have):
+ """ The command generator when state is replaced
+
+ :rtype: A list
+ :returns: the commands necessary to migrate the current configuration
+ to the desired configuration
+ """
+ commands = []
+ if have:
+ commands.extend(self._state_deleted(want, have))
+ commands.extend(self._state_merged(want, have))
+ return commands
+
+ def _state_overridden(self, want, have):
+ """ The command generator when state is overridden
+
+ :rtype: A list
+ :returns: the commands necessary to migrate the current configuration
+ to the desired configuration
+ """
+ commands = []
+ for have_item in have:
+ lldp_name = have_item["name"]
+ lldp_in_want = search_obj_in_list(lldp_name, want)
+ if not lldp_in_want:
+ commands.append(
+ self._compute_command(have_item["name"], remove=True)
+ )
+
+ for want_item in want:
+ name = want_item["name"]
+ lldp_in_have = search_obj_in_list(name, have)
+ commands.extend(self._state_replaced(want_item, lldp_in_have))
+ return commands
+
+ def _state_merged(self, want, have):
+ """ The command generator when state is merged
+
+ :rtype: A list
+ :returns: the commands necessary to merge the provided into
+ the current configuration
+ """
+ commands = []
+ if have:
+ commands.extend(self._render_updates(want, have))
+ else:
+ commands.extend(self._render_set_commands(want))
+ return commands
+
+ def _state_deleted(self, want, have):
+ """ The command generator when state is deleted
+
+ :rtype: A list
+ :returns: the commands necessary to remove the current configuration
+ of the provided objects
+ """
+ commands = []
+ if want:
+ params = Lldp_interfaces.params
+ for attrib in params:
+ if attrib == "location":
+ commands.extend(
+ self._update_location(have["name"], want, have)
+ )
+
+ elif have:
+ commands.append(self._compute_command(have["name"], remove=True))
+ return commands
+
+ def _render_updates(self, want, have):
+ commands = []
+ lldp_name = have["name"]
+ commands.extend(self._configure_status(lldp_name, want, have))
+ commands.extend(self._add_location(lldp_name, want, have))
+
+ return commands
+
+ def _render_set_commands(self, want):
+ commands = []
+ have = {}
+ lldp_name = want["name"]
+ params = Lldp_interfaces.params
+
+ commands.extend(self._add_location(lldp_name, want, have))
+ for attrib in params:
+ value = want[attrib]
+ if value:
+ if attrib == "location":
+ commands.extend(self._add_location(lldp_name, want, have))
+ elif attrib == "enable":
+ if not value:
+ commands.append(
+ self._compute_command(lldp_name, value="disable")
+ )
+ else:
+ commands.append(self._compute_command(lldp_name))
+
+ return commands
+
+ def _configure_status(self, name, want_item, have_item):
+ commands = []
+ if is_dict_element_present(have_item, "enable"):
+ temp_have_item = False
+ else:
+ temp_have_item = True
+ if want_item["enable"] != temp_have_item:
+ if want_item["enable"]:
+ commands.append(
+ self._compute_command(name, value="disable", remove=True)
+ )
+ else:
+ commands.append(self._compute_command(name, value="disable"))
+ return commands
+
+ def _add_location(self, name, want_item, have_item):
+ commands = []
+ have_dict = {}
+ have_ca = {}
+ set_cmd = name + " location "
+ want_location_type = want_item.get("location") or {}
+ have_location_type = have_item.get("location") or {}
+
+ if want_location_type["coordinate_based"]:
+ want_dict = want_location_type.get("coordinate_based") or {}
+ if is_dict_element_present(have_location_type, "coordinate_based"):
+ have_dict = have_location_type.get("coordinate_based") or {}
+ location_type = "coordinate-based"
+ updates = dict_diff(have_dict, want_dict)
+ for key, value in iteritems(updates):
+ if value:
+ commands.append(
+ self._compute_command(
+ set_cmd + location_type, key, str(value)
+ )
+ )
+
+ elif want_location_type["civic_based"]:
+ location_type = "civic-based"
+ want_dict = want_location_type.get("civic_based") or {}
+ want_ca = want_dict.get("ca_info") or []
+ if is_dict_element_present(have_location_type, "civic_based"):
+ have_dict = have_location_type.get("civic_based") or {}
+ have_ca = have_dict.get("ca_info") or []
+ if want_dict["country_code"] != have_dict["country_code"]:
+ commands.append(
+ self._compute_command(
+ set_cmd + location_type,
+ "country-code",
+ str(want_dict["country_code"]),
+ )
+ )
+ else:
+ commands.append(
+ self._compute_command(
+ set_cmd + location_type,
+ "country-code",
+ str(want_dict["country_code"]),
+ )
+ )
+ commands.extend(self._add_civic_address(name, want_ca, have_ca))
+
+ elif want_location_type["elin"]:
+ location_type = "elin"
+ if is_dict_element_present(have_location_type, "elin"):
+ if want_location_type.get("elin") != have_location_type.get(
+ "elin"
+ ):
+ commands.append(
+ self._compute_command(
+ set_cmd + location_type,
+ value=str(want_location_type["elin"]),
+ )
+ )
+ else:
+ commands.append(
+ self._compute_command(
+ set_cmd + location_type,
+ value=str(want_location_type["elin"]),
+ )
+ )
+ return commands
+
+ def _update_location(self, name, want_item, have_item):
+ commands = []
+ del_cmd = name + " location"
+ want_location_type = want_item.get("location") or {}
+ have_location_type = have_item.get("location") or {}
+
+ if want_location_type["coordinate_based"]:
+ want_dict = want_location_type.get("coordinate_based") or {}
+ if is_dict_element_present(have_location_type, "coordinate_based"):
+ have_dict = have_location_type.get("coordinate_based") or {}
+ location_type = "coordinate-based"
+ for key, value in iteritems(have_dict):
+ only_in_have = key_value_in_dict(key, value, want_dict)
+ if not only_in_have:
+ commands.append(
+ self._compute_command(
+ del_cmd + location_type, key, str(value), True
+ )
+ )
+ else:
+ commands.append(self._compute_command(del_cmd, remove=True))
+
+ elif want_location_type["civic_based"]:
+ want_dict = want_location_type.get("civic_based") or {}
+ want_ca = want_dict.get("ca_info") or []
+ if is_dict_element_present(have_location_type, "civic_based"):
+ have_dict = have_location_type.get("civic_based") or {}
+ have_ca = have_dict.get("ca_info")
+ commands.extend(
+ self._update_civic_address(name, want_ca, have_ca)
+ )
+ else:
+ commands.append(self._compute_command(del_cmd, remove=True))
+
+ else:
+ if is_dict_element_present(have_location_type, "elin"):
+ if want_location_type.get("elin") != have_location_type.get(
+ "elin"
+ ):
+ commands.append(
+ self._compute_command(del_cmd, remove=True)
+ )
+ else:
+ commands.append(self._compute_command(del_cmd, remove=True))
+ return commands
+
+ def _add_civic_address(self, name, want, have):
+ commands = []
+ for item in want:
+ ca_type = item["ca_type"]
+ ca_value = item["ca_value"]
+ obj_in_have = search_dict_tv_in_list(
+ ca_type, ca_value, have, "ca_type", "ca_value"
+ )
+ if not obj_in_have:
+ commands.append(
+ self._compute_command(
+ key=name + " location civic-based ca-type",
+ attrib=str(ca_type) + " ca-value",
+ value=ca_value,
+ )
+ )
+ return commands
+
+ def _update_civic_address(self, name, want, have):
+ commands = []
+ for item in have:
+ ca_type = item["ca_type"]
+ ca_value = item["ca_value"]
+ in_want = search_dict_tv_in_list(
+ ca_type, ca_value, want, "ca_type", "ca_value"
+ )
+ if not in_want:
+ commands.append(
+ self._compute_command(
+ name,
+ "location civic-based ca-type",
+ str(ca_type),
+ remove=True,
+ )
+ )
+ return commands
+
+ def _compute_command(self, key, attrib=None, value=None, remove=False):
+ if remove:
+ cmd = "delete service lldp interface "
+ else:
+ cmd = "set service lldp interface "
+ cmd += key
+ if attrib:
+ cmd += " " + attrib
+ if value:
+ cmd += " '" + value + "'"
+ return cmd
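+
+ # Editorial example:
+ # _compute_command("eth1", attrib="location elin", value="0000000911")
+ # returns "set service lldp interface eth1 location elin '0000000911'"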
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py
new file mode 100644
index 00000000..8f0a3bb6
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py
@@ -0,0 +1,83 @@
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The facts class for vyos
+this file validates each subset of facts and selectively
+calls the appropriate facts gathering function
+"""
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.facts.facts import (
+ FactsBase,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.interfaces.interfaces import (
+ InterfacesFacts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.l3_interfaces.l3_interfaces import (
+ L3_interfacesFacts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lag_interfaces.lag_interfaces import (
+ Lag_interfacesFacts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lldp_global.lldp_global import (
+ Lldp_globalFacts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lldp_interfaces.lldp_interfaces import (
+ Lldp_interfacesFacts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.firewall_rules.firewall_rules import (
+ Firewall_rulesFacts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.static_routes.static_routes import (
+ Static_routesFacts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.legacy.base import (
+ Default,
+ Neighbors,
+ Config,
+)
+
+
+FACT_LEGACY_SUBSETS = dict(default=Default, neighbors=Neighbors, config=Config)
+FACT_RESOURCE_SUBSETS = dict(
+ interfaces=InterfacesFacts,
+ l3_interfaces=L3_interfacesFacts,
+ lag_interfaces=Lag_interfacesFacts,
+ lldp_global=Lldp_globalFacts,
+ lldp_interfaces=Lldp_interfacesFacts,
+ static_routes=Static_routesFacts,
+ firewall_rules=Firewall_rulesFacts,
+)
+
+
+class Facts(FactsBase):
+ """ The fact class for vyos
+ """
+
+ VALID_LEGACY_GATHER_SUBSETS = frozenset(FACT_LEGACY_SUBSETS.keys())
+ VALID_RESOURCE_SUBSETS = frozenset(FACT_RESOURCE_SUBSETS.keys())
+
+ def __init__(self, module):
+ super(Facts, self).__init__(module)
+
+ def get_facts(
+ self, legacy_facts_type=None, resource_facts_type=None, data=None
+ ):
+ """ Collect the facts for vyos
+ :param legacy_facts_type: List of legacy facts types
+ :param resource_facts_type: List of resource fact types
+ :param data: previously collected conf
+ :rtype: dict
+ :return: the facts gathered
+ """
+ if self.VALID_RESOURCE_SUBSETS:
+ self.get_network_resources_facts(
+ FACT_RESOURCE_SUBSETS, resource_facts_type, data
+ )
+ if self.VALID_LEGACY_GATHER_SUBSETS:
+ self.get_network_legacy_facts(
+ FACT_LEGACY_SUBSETS, legacy_facts_type
+ )
+ return self.ansible_facts, self._warnings
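+
+ # Editorial example:
+ # facts, warnings = Facts(module).get_facts(
+ #     legacy_facts_type=["default"], resource_facts_type=["interfaces"])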
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py
new file mode 100644
index 00000000..971ea6fe
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py
@@ -0,0 +1,380 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos firewall_rules fact class
+It is in this file that the configuration is collected from the device
+for a given resource, parsed, and the facts tree is populated
+based on the configuration.
+"""
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+from re import findall, search, M
+from copy import deepcopy
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import (
+ utils,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.firewall_rules.firewall_rules import (
+ Firewall_rulesArgs,
+)
+
+
+class Firewall_rulesFacts(object):
+ """ The vyos firewall_rules fact class
+ """
+
+ def __init__(self, module, subspec="config", options="options"):
+ self._module = module
+ self.argument_spec = Firewall_rulesArgs.argument_spec
+ spec = deepcopy(self.argument_spec)
+ if subspec:
+ if options:
+ facts_argument_spec = spec[subspec][options]
+ else:
+ facts_argument_spec = spec[subspec]
+ else:
+ facts_argument_spec = spec
+
+ self.generated_spec = utils.generate_dict(facts_argument_spec)
+
+ def get_device_data(self, connection):
+ return connection.get_config()
+
+ def populate_facts(self, connection, ansible_facts, data=None):
+ """ Populate the facts for firewall_rules
+ :param connection: the device connection
+ :param ansible_facts: Facts dictionary
+ :param data: previously collected conf
+ :rtype: dictionary
+ :returns: facts
+ """
+ if not data:
+ # data is populated from the current device configuration
+ data = self.get_device_data(connection)
+ # split the config into instances of the resource
+ objs = []
+ v6_rules = findall(
+ r"^set firewall ipv6-name (?:\'*)(\S+)(?:\'*)", data, M
+ )
+ v4_rules = findall(r"^set firewall name (?:\'*)(\S+)(?:\'*)", data, M)
+ if v6_rules:
+ config = self.get_rules(data, v6_rules, type="ipv6")
+ if config:
+ config = utils.remove_empties(config)
+ objs.append(config)
+ if v4_rules:
+ config = self.get_rules(data, v4_rules, type="ipv4")
+ if config:
+ config = utils.remove_empties(config)
+ objs.append(config)
+
+ ansible_facts["ansible_network_resources"].pop("firewall_rules", None)
+ facts = {}
+ if objs:
+ facts["firewall_rules"] = []
+ params = utils.validate_config(
+ self.argument_spec, {"config": objs}
+ )
+ for cfg in params["config"]:
+ facts["firewall_rules"].append(utils.remove_empties(cfg))
+
+ ansible_facts["ansible_network_resources"].update(facts)
+ return ansible_facts
+
+ def get_rules(self, data, rules, type):
+ """
+ This function performs the following:
+ - Forms the regex to fetch 'rule-sets' specific config from data.
+ - Forms the rule-set list based on IP address type.
+ :param data: configuration.
+ :param rules: list of rule-sets.
+ :param type: ip address type.
+ :return: generated rule-sets configuration.
+ """
+ r_v4 = []
+ r_v6 = []
+ for r in set(rules):
+ rule_regex = r" %s .+$" % r.strip("'")
+ cfg = findall(rule_regex, data, M)
+ fr = self.render_config(cfg, r.strip("'"))
+ fr["name"] = r.strip("'")
+ if type == "ipv6":
+ r_v6.append(fr)
+ else:
+ r_v4.append(fr)
+ if r_v4:
+ config = {"afi": "ipv4", "rule_sets": r_v4}
+ if r_v6:
+ config = {"afi": "ipv6", "rule_sets": r_v6}
+ return config
+
+ def render_config(self, conf, match):
+ """
+ Render config as dictionary structure and delete keys
+ from spec for null values
+
+ :param conf: The configuration
+ :param match: The rule-set name used to scope the parse
+ :rtype: dictionary
+ :returns: The generated config
+ """
+ conf = "\n".join(filter(lambda x: x, conf))
+ a_lst = ["description", "default_action", "enable_default_log"]
+ config = self.parse_attr(conf, a_lst, match)
+ if not config:
+ config = {}
+ config["rules"] = self.parse_rules_lst(conf)
+ return config
+
+ def parse_rules_lst(self, conf):
+ """
+ This function forms the regex to fetch the 'rules' within
+ 'rule-sets'.
+ :param conf: configuration data.
+ :return: generated rule list configuration.
+ """
+ r_lst = []
+ rules = findall(r"rule (?:\'*)(\d+)(?:\'*)", conf, M)
+ if rules:
+ rules_lst = []
+ for r in set(rules):
+ r_regex = r" %s .+$" % r
+ cfg = "\n".join(findall(r_regex, conf, M))
+ obj = self.parse_rules(cfg)
+ obj["number"] = int(r)
+ if obj:
+ rules_lst.append(obj)
+ r_lst = sorted(rules_lst, key=lambda i: i["number"])
+ return r_lst
+
+ def parse_rules(self, conf):
+ """
+ This function triggers the parsing of 'rule' attributes.
+ a_lst is a list of rule attributes that don't
+ have further sub-attributes.
+ :param conf: configuration
+ :return: generated rule configuration dictionary.
+ """
+ a_lst = [
+ "ipsec",
+ "action",
+ "protocol",
+ "fragment",
+ "disabled",
+ "description",
+ ]
+ rule = self.parse_attr(conf, a_lst)
+ r_sub = {
+ "p2p": self.parse_p2p(conf),
+ "tcp": self.parse_tcp(conf, "tcp"),
+ "icmp": self.parse_icmp(conf, "icmp"),
+ "time": self.parse_time(conf, "time"),
+ "limit": self.parse_limit(conf, "limit"),
+ "state": self.parse_state(conf, "state"),
+ "recent": self.parse_recent(conf, "recent"),
+ "source": self.parse_src_or_dest(conf, "source"),
+ "destination": self.parse_src_or_dest(conf, "destination"),
+ }
+ rule.update(r_sub)
+ return rule
+
+ def parse_p2p(self, conf):
+ """
+ This function forms the regex to fetch the 'p2p' applications within
+ 'rules'.
+ :param conf: configuration data.
+ :return: generated p2p application list.
+ """
+ a_lst = []
+ applications = findall(r"p2p (?:\'*)(\S+)(?:\'*)", conf, M)
+ if applications:
+ app_lst = []
+ for r in set(applications):
+ obj = {"application": r.strip("'")}
+ app_lst.append(obj)
+ a_lst = sorted(app_lst, key=lambda i: i["application"])
+ return a_lst
+
+ def parse_src_or_dest(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'source or
+ destination' attributes.
+ :param conf: configuration.
+ :param attrib:'source/destination'.
+ :return:generated source/destination configuration dictionary.
+ """
+ a_lst = ["port", "address", "mac_address"]
+ cfg_dict = self.parse_attr(conf, a_lst, match=attrib)
+ cfg_dict["group"] = self.parse_group(conf, attrib + " group")
+ return cfg_dict
+
+ def parse_recent(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'recent' attributes
+ :param conf: configuration.
+ :param attrib: 'recent'.
+ :return: generated config dictionary.
+ """
+ a_lst = ["time", "count"]
+ cfg_dict = self.parse_attr(conf, a_lst, match=attrib)
+ return cfg_dict
+
+ def parse_tcp(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'tcp' attributes.
+ :param conf: configuration.
+ :param attrib: 'tcp'.
+ :return: generated config dictionary.
+ """
+ cfg_dict = self.parse_attr(conf, ["flags"], match=attrib)
+ return cfg_dict
+
+ def parse_time(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'time' attributes.
+ :param conf: configuration.
+ :param attrib: 'time'.
+ :return: generated config dictionary.
+ """
+ a_lst = [
+ "stopdate",
+ "stoptime",
+ "weekdays",
+ "monthdays",
+ "startdate",
+ "starttime",
+ ]
+ cfg_dict = self.parse_attr(conf, a_lst, match=attrib)
+ return cfg_dict
+
+ def parse_state(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'state' attributes.
+ :param conf: configuration
+ :param attrib: 'state'.
+ :return: generated config dictionary.
+ """
+ a_lst = ["new", "invalid", "related", "established"]
+ cfg_dict = self.parse_attr(conf, a_lst, match=attrib)
+ return cfg_dict
+
+ def parse_group(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'group' attributes.
+ :param conf: configuration.
+ :param attrib: 'group'.
+ :return: generated config dictionary.
+ """
+ a_lst = ["port_group", "address_group", "network_group"]
+ cfg_dict = self.parse_attr(conf, a_lst, match=attrib)
+ return cfg_dict
+
+ def parse_icmp(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'icmp' attributes.
+ :param conf: configuration to be parsed.
+ :param attrib: 'icmp'.
+ :return: generated config dictionary.
+ """
+ a_lst = ["code", "type", "type_name"]
+ cfg_dict = self.parse_attr(conf, a_lst, match=attrib)
+ return cfg_dict
+
+ def parse_limit(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'limit' attributes.
+ :param conf: configuration to be parsed.
+ :param attrib: 'limit'
+ :return: generated config dictionary.
+ """
+ cfg_dict = self.parse_attr(conf, ["burst"], match=attrib)
+ cfg_dict["rate"] = self.parse_rate(conf, "rate")
+ return cfg_dict
+
+ def parse_rate(self, conf, attrib=None):
+ """
+ This function triggers the parsing of 'rate' attributes.
+ :param conf: configuration.
+ :param attrib: 'rate'
+ :return: generated config dictionary.
+ """
+ a_lst = ["unit", "number"]
+ cfg_dict = self.parse_attr(conf, a_lst, match=attrib)
+ return cfg_dict
+
+ def parse_attr(self, conf, attr_list, match=None):
+ """
+ This function performs the following:
+ - Forms the regex to fetch the required attribute config.
+ - Type-casts the output to the desired format.
+ :param conf: configuration.
+ :param attr_list: list of attributes.
+ :param match: parent node/attribute name.
+ :return: generated config dictionary.
+ """
+ config = {}
+ for attrib in attr_list:
+ regex = self.map_regex(attrib)
+ if match:
+ regex = match + " " + regex
+ if conf:
+ if self.is_bool(attrib):
+ out = conf.find(attrib.replace("_", "-"))
+
+ dis = conf.find(attrib.replace("_", "-") + " 'disable'")
+ if out >= 1:
+ if dis >= 1:
+ config[attrib] = False
+ else:
+ config[attrib] = True
+ else:
+ out = search(r"^.*" + regex + " (.+)", conf, M)
+ if out:
+ val = out.group(1).strip("'")
+ if self.is_num(attrib):
+ val = int(val)
+ config[attrib] = val
+ return config
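+
+ # A minimal usage sketch with hypothetical config text (not part of the
+ # module flow): the "<match> <attr>" regex extracts the quoted value,
+ # while is_num()/is_bool() drive the type casting.
+ #   >>> conf = "set firewall name T rule 1 source address '192.0.2.1'"
+ #   >>> self.parse_attr(conf, ["address"], match="source")
+ #   {'address': '192.0.2.1'}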
+
+ def map_regex(self, attrib):
+ """
+ - This function constructs the regex string.
+ - Replaces underscores with hyphens.
+ :param attrib: attribute
+ :return: regex string
+ """
+ regex = attrib.replace("_", "-")
+ if attrib == "disabled":
+ regex = "disable"
+ return regex
+
+ def is_bool(self, attrib):
+ """
+ This function looks for the attribute in predefined bool type set.
+ :param attrib: attribute.
+ :return: True/False.
+ """
+ bool_set = (
+ "new",
+ "invalid",
+ "related",
+ "disabled",
+ "established",
+ "enable_default_log",
+ )
+ return attrib in bool_set
+
+ def is_num(self, attrib):
+ """
+ This function looks for the attribute in predefined integer type set.
+ :param attrib: attribute.
+ :return: True/False.
+ """
+ num_set = ("time", "code", "type", "count", "burst", "number")
+ return attrib in num_set
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py
new file mode 100644
index 00000000..4b24803b
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py
@@ -0,0 +1,134 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos interfaces fact class
+This file collects the configuration from the device for a given
+resource, parses it, and populates the facts tree based on the
+configuration.
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+from re import findall, M
+from copy import deepcopy
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import (
+ utils,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.interfaces.interfaces import (
+ InterfacesArgs,
+)
+
+
+class InterfacesFacts(object):
+ """ The vyos interfaces fact class
+ """
+
+ def __init__(self, module, subspec="config", options="options"):
+ self._module = module
+ self.argument_spec = InterfacesArgs.argument_spec
+ spec = deepcopy(self.argument_spec)
+ if subspec:
+ if options:
+ facts_argument_spec = spec[subspec][options]
+ else:
+ facts_argument_spec = spec[subspec]
+ else:
+ facts_argument_spec = spec
+
+ self.generated_spec = utils.generate_dict(facts_argument_spec)
+
+ def populate_facts(self, connection, ansible_facts, data=None):
+ """ Populate the facts for interfaces
+ :param connection: the device connection
+ :param ansible_facts: Facts dictionary
+ :param data: previously collected conf
+ :rtype: dictionary
+ :returns: facts
+ """
+ if not data:
+ data = connection.get_config(flags=["| grep interfaces"])
+
+ objs = []
+ interface_names = findall(
+ r"^set interfaces (?:ethernet|bonding|vti|loopback|vxlan) (?:\'*)(\S+)(?:\'*)",
+ data,
+ M,
+ )
+ if interface_names:
+ for interface in set(interface_names):
+ intf_regex = r" %s .+$" % interface.strip("'")
+ cfg = findall(intf_regex, data, M)
+ obj = self.render_config(cfg)
+ obj["name"] = interface.strip("'")
+ if obj:
+ objs.append(obj)
+ facts = {}
+ if objs:
+ facts["interfaces"] = []
+ params = utils.validate_config(
+ self.argument_spec, {"config": objs}
+ )
+ for cfg in params["config"]:
+ facts["interfaces"].append(utils.remove_empties(cfg))
+
+ ansible_facts["ansible_network_resources"].update(facts)
+ return ansible_facts
+
+ def render_config(self, conf):
+ """
+ Render config as dictionary structure and delete keys
+ from spec for null values
+
+ :param spec: The facts tree, generated from the argspec
+ :param conf: The configuration
+ :rtype: dictionary
+ :returns: The generated config
+ """
+ vif_conf = "\n".join(filter(lambda x: ("vif" in x), conf))
+ eth_conf = "\n".join(filter(lambda x: ("vif" not in x), conf))
+ config = self.parse_attribs(
+ ["description", "speed", "mtu", "duplex"], eth_conf
+ )
+ config["vifs"] = self.parse_vifs(vif_conf)
+
+ return utils.remove_empties(config)
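+
+ # Shape sketch with hypothetical device lines (values illustrative):
+ #   conf = [" eth0 mtu '1500'", " eth0 vif 10 description 'VLAN10'"]
+ # render_config(conf) then returns roughly
+ #   {"mtu": 1500, "enabled": True,
+ #    "vifs": [{"vlan_id": 10, "description": "VLAN10", "enabled": True}]}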
+
+ def parse_vifs(self, conf):
+ vif_names = findall(r"vif (?:\'*)(\d+)(?:\'*)", conf, M)
+ vifs_list = None
+
+ if vif_names:
+ vifs_list = []
+ for vif in set(vif_names):
+ vif_regex = r" %s .+$" % vif
+ cfg = "\n".join(findall(vif_regex, conf, M))
+ obj = self.parse_attribs(["description", "mtu"], cfg)
+ obj["vlan_id"] = int(vif)
+ if obj:
+ vifs_list.append(obj)
+ vifs_list = sorted(vifs_list, key=lambda i: i["vlan_id"])
+
+ return vifs_list
+
+ def parse_attribs(self, attribs, conf):
+ config = {}
+ for item in attribs:
+ value = utils.parse_conf_arg(conf, item)
+ if value and item == "mtu":
+ config[item] = int(value.strip("'"))
+ elif value:
+ config[item] = value.strip("'")
+ else:
+ config[item] = None
+ if "disable" in conf:
+ config["enabled"] = False
+ else:
+ config["enabled"] = True
+
+ return utils.remove_empties(config)
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py
new file mode 100644
index 00000000..d1d62c23
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py
@@ -0,0 +1,143 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos l3_interfaces fact class
+This file collects the configuration from the device for a given
+resource, parses it, and populates the facts tree based on the
+configuration.
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+import re
+from copy import deepcopy
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import (
+ utils,
+)
+from ansible.module_utils.six import iteritems
+from ansible_collections.ansible.netcommon.plugins.module_utils.compat import (
+ ipaddress,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.l3_interfaces.l3_interfaces import (
+ L3_interfacesArgs,
+)
+
+
+class L3_interfacesFacts(object):
+ """ The vyos l3_interfaces fact class
+ """
+
+ def __init__(self, module, subspec="config", options="options"):
+ self._module = module
+ self.argument_spec = L3_interfacesArgs.argument_spec
+ spec = deepcopy(self.argument_spec)
+ if subspec:
+ if options:
+ facts_argument_spec = spec[subspec][options]
+ else:
+ facts_argument_spec = spec[subspec]
+ else:
+ facts_argument_spec = spec
+
+ self.generated_spec = utils.generate_dict(facts_argument_spec)
+
+ def populate_facts(self, connection, ansible_facts, data=None):
+ """ Populate the facts for l3_interfaces
+ :param connection: the device connection
+ :param ansible_facts: Facts dictionary
+ :param data: previously collected conf
+ :rtype: dictionary
+ :returns: facts
+ """
+ if not data:
+ data = connection.get_config()
+
+ # build one config object per interface discovered in the config
+ objs = []
+ interface_names = re.findall(
+ r"set interfaces (?:ethernet|bonding|vti|vxlan) (?:\'*)(\S+)(?:\'*)",
+ data,
+ re.M,
+ )
+ if interface_names:
+ for interface in set(interface_names):
+ intf_regex = r" %s .+$" % interface
+ cfg = re.findall(intf_regex, data, re.M)
+ obj = self.render_config(cfg)
+ obj["name"] = interface.strip("'")
+ if obj:
+ objs.append(obj)
+
+ ansible_facts["ansible_network_resources"].pop("l3_interfaces", None)
+ facts = {}
+ if objs:
+ facts["l3_interfaces"] = []
+ params = utils.validate_config(
+ self.argument_spec, {"config": objs}
+ )
+ for cfg in params["config"]:
+ facts["l3_interfaces"].append(utils.remove_empties(cfg))
+
+ ansible_facts["ansible_network_resources"].update(facts)
+ return ansible_facts
+
+ def render_config(self, conf):
+ """
+ Render config as dictionary structure and delete keys from spec for null values
+ :param spec: The facts tree, generated from the argspec
+ :param conf: The configuration
+ :rtype: dictionary
+ :returns: The generated config
+ """
+ vif_conf = "\n".join(filter(lambda x: ("vif" in x), conf))
+ eth_conf = "\n".join(filter(lambda x: ("vif" not in x), conf))
+ config = self.parse_attribs(eth_conf)
+ config["vifs"] = self.parse_vifs(vif_conf)
+
+ return utils.remove_empties(config)
+
+ def parse_vifs(self, conf):
+ vif_names = re.findall(r"vif (\d+)", conf, re.M)
+ vifs_list = None
+ if vif_names:
+ vifs_list = []
+ for vif in set(vif_names):
+ vif_regex = r" %s .+$" % vif
+ cfg = "\n".join(re.findall(vif_regex, conf, re.M))
+ obj = self.parse_attribs(cfg)
+ obj["vlan_id"] = vif
+ if obj:
+ vifs_list.append(obj)
+
+ return vifs_list
+
+ def parse_attribs(self, conf):
+ config = {}
+ ipaddrs = re.findall(r"address (\S+)", conf, re.M)
+ config["ipv4"] = []
+ config["ipv6"] = []
+
+ for item in ipaddrs:
+ item = item.strip("'")
+ if item == "dhcp":
+ config["ipv4"].append({"address": item})
+ elif item == "dhcpv6":
+ config["ipv6"].append({"address": item})
+ else:
+ ip_version = ipaddress.ip_address(item.split("/")[0]).version
+ if ip_version == 4:
+ config["ipv4"].append({"address": item})
+ else:
+ config["ipv6"].append({"address": item})
+
+ for key, value in iteritems(config):
+ if value == []:
+ config[key] = None
+
+ return utils.remove_empties(config)
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py
new file mode 100644
index 00000000..9201e5c6
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py
@@ -0,0 +1,152 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos lag_interfaces fact class
+This file collects the configuration from the device for a given
+resource, parses it, and populates the facts tree based on the
+configuration.
+"""
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+from re import findall, search, M
+from copy import deepcopy
+
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import (
+ utils,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lag_interfaces.lag_interfaces import (
+ Lag_interfacesArgs,
+)
+
+
+class Lag_interfacesFacts(object):
+ """ The vyos lag_interfaces fact class
+ """
+
+ def __init__(self, module, subspec="config", options="options"):
+ self._module = module
+ self.argument_spec = Lag_interfacesArgs.argument_spec
+ spec = deepcopy(self.argument_spec)
+ if subspec:
+ if options:
+ facts_argument_spec = spec[subspec][options]
+ else:
+ facts_argument_spec = spec[subspec]
+ else:
+ facts_argument_spec = spec
+
+ self.generated_spec = utils.generate_dict(facts_argument_spec)
+
+ def populate_facts(self, connection, ansible_facts, data=None):
+ """ Populate the facts for lag_interfaces
+ :param connection: the device connection
+ :param ansible_facts: Facts dictionary
+ :param data: previously collected conf
+ :rtype: dictionary
+ :returns: facts
+ """
+ if not data:
+ data = connection.get_config()
+
+ objs = []
+ lag_names = findall(r"^set interfaces bonding (\S+)", data, M)
+ if lag_names:
+ for lag in set(lag_names):
+ lag_regex = r" %s .+$" % lag
+ cfg = findall(lag_regex, data, M)
+ obj = self.render_config(cfg)
+
+ output = connection.run_commands(
+ ["show interfaces bonding " + lag + " slaves"]
+ )
+ lines = output[0].splitlines()
+ members = []
+ member = {}
+ if len(lines) > 1:
+ for line in lines[2:]:
+ splitted_line = line.split()
+
+ if len(splitted_line) > 1:
+ # build a fresh dict per slave so list entries are not aliased
+ member = {"member": splitted_line[0]}
+ members.append(member)
+ else:
+ members = []
+ member = {}
+ obj["name"] = lag.strip("'")
+ if members:
+ obj["members"] = members
+
+ if obj:
+ objs.append(obj)
+
+ facts = {}
+ if objs:
+ facts["lag_interfaces"] = []
+ params = utils.validate_config(
+ self.argument_spec, {"config": objs}
+ )
+ for cfg in params["config"]:
+ facts["lag_interfaces"].append(utils.remove_empties(cfg))
+
+ ansible_facts["ansible_network_resources"].update(facts)
+ return ansible_facts
+
+ def render_config(self, conf):
+ """
+ Render config as dictionary structure and delete keys
+ from spec for null values
+
+ :param spec: The facts tree, generated from the argspec
+ :param conf: The configuration
+ :rtype: dictionary
+ :returns: The generated config
+ """
+ arp_monitor_conf = "\n".join(
+ filter(lambda x: ("arp-monitor" in x), conf)
+ )
+ hash_policy_conf = "\n".join(
+ filter(lambda x: ("hash-policy" in x), conf)
+ )
+ lag_conf = "\n".join(filter(lambda x: ("bond" in x), conf))
+ config = self.parse_attribs(["mode", "primary"], lag_conf)
+ config["arp_monitor"] = self.parse_arp_monitor(arp_monitor_conf)
+ config["hash_policy"] = self.parse_hash_policy(hash_policy_conf)
+
+ return utils.remove_empties(config)
+
+ def parse_attribs(self, attribs, conf):
+ config = {}
+ for item in attribs:
+ value = utils.parse_conf_arg(conf, item)
+ if value:
+ config[item] = value.strip("'")
+ else:
+ config[item] = None
+ return utils.remove_empties(config)
+
+ def parse_arp_monitor(self, conf):
+ arp_monitor = None
+ if conf:
+ arp_monitor = {}
+ target_list = []
+ interval = search(r"^.*arp-monitor interval (.+)", conf, M)
+ targets = findall(r"^.*arp-monitor target '(.+)'", conf, M)
+ if targets:
+ for target in targets:
+ target_list.append(target)
+ arp_monitor["target"] = target_list
+ if interval:
+ value = interval.group(1).strip("'")
+ arp_monitor["interval"] = int(value)
+ return arp_monitor
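+
+ # Result sketch for hypothetical lines "... arp-monitor interval '100'"
+ # plus two "arp-monitor target" entries:
+ #   {"target": ["192.0.2.1", "192.0.2.2"], "interval": 100}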
+
+ def parse_hash_policy(self, conf):
+ hash_policy = None
+ if conf:
+ hash_policy = search(r"^.*hash-policy (.+)", conf, M)
+ hash_policy = hash_policy.group(1).strip("'")
+ return hash_policy
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py
new file mode 100644
index 00000000..f6b343e0
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py
@@ -0,0 +1,162 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The VyOS legacy fact classes
+This file collects command output from the device for the legacy
+(default, config, neighbors) facts, parses it, and populates the
+facts tree based on that output.
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+import platform
+import re
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import (
+ run_commands,
+ get_capabilities,
+)
+
+
+class LegacyFactsBase(object):
+
+ COMMANDS = frozenset()
+
+ def __init__(self, module):
+ self.module = module
+ self.facts = dict()
+ self.warnings = list()
+ self.responses = None
+
+ def populate(self):
+ self.responses = run_commands(self.module, list(self.COMMANDS))
+
+
+class Default(LegacyFactsBase):
+
+ COMMANDS = [
+ "show version",
+ ]
+
+ def populate(self):
+ super(Default, self).populate()
+ data = self.responses[0]
+ self.facts["serialnum"] = self.parse_serialnum(data)
+ self.facts.update(self.platform_facts())
+
+ def parse_serialnum(self, data):
+ match = re.search(r"HW S/N:\s+(\S+)", data)
+ if match:
+ return match.group(1)
+
+ def platform_facts(self):
+ platform_facts = {}
+
+ resp = get_capabilities(self.module)
+ device_info = resp["device_info"]
+
+ platform_facts["system"] = device_info["network_os"]
+
+ for item in ("model", "image", "version", "platform", "hostname"):
+ val = device_info.get("network_os_%s" % item)
+ if val:
+ platform_facts[item] = val
+
+ platform_facts["api"] = resp["network_api"]
+ platform_facts["python_version"] = platform.python_version()
+
+ return platform_facts
+
+
+class Config(LegacyFactsBase):
+
+ COMMANDS = [
+ "show configuration commands",
+ "show system commit",
+ ]
+
+ def populate(self):
+ super(Config, self).populate()
+
+ self.facts["config"] = self.responses
+
+ commits = self.responses[1]
+ entries = list()
+ entry = None
+
+ for line in commits.split("\n"):
+ match = re.match(r"(\d+)\s+(.+)by(.+)via(.+)", line)
+ if match:
+ if entry:
+ entries.append(entry)
+
+ entry = dict(
+ revision=match.group(1),
+ datetime=match.group(2),
+ by=str(match.group(3)).strip(),
+ via=str(match.group(4)).strip(),
+ comment=None,
+ )
+ elif entry:
+ entry["comment"] = line.strip()
+
+ # append the final parsed entry once the loop completes
+ if entry:
+ entries.append(entry)
+
+ self.facts["commits"] = entries
+
+
+class Neighbors(LegacyFactsBase):
+
+ COMMANDS = [
+ "show lldp neighbors",
+ "show lldp neighbors detail",
+ ]
+
+ def populate(self):
+ super(Neighbors, self).populate()
+
+ all_neighbors = self.responses[0]
+ if "LLDP not configured" not in all_neighbors:
+ neighbors = self.parse(self.responses[1])
+ self.facts["neighbors"] = self.parse_neighbors(neighbors)
+
+ def parse(self, data):
+ parsed = list()
+ values = None
+ for line in data.split("\n"):
+ if not line:
+ continue
+ elif line[0] == " ":
+ values += "\n%s" % line
+ elif line.startswith("Interface"):
+ if values:
+ parsed.append(values)
+ values = line
+ if values:
+ parsed.append(values)
+ return parsed
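+
+ # Sketch: the detail output is split into one string per neighbor, using
+ # lines starting with "Interface" as record boundaries; indented
+ # continuation lines are folded into the current record.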
+
+ def parse_neighbors(self, data):
+ facts = dict()
+ for item in data:
+ interface = self.parse_interface(item)
+ host = self.parse_host(item)
+ port = self.parse_port(item)
+ if interface not in facts:
+ facts[interface] = list()
+ facts[interface].append(dict(host=host, port=port))
+ return facts
+
+ def parse_interface(self, data):
+ match = re.search(r"^Interface:\s+(\S+),", data)
+ return match.group(1)
+
+ def parse_host(self, data):
+ match = re.search(r"SysName:\s+(.+)$", data, re.M)
+ if match:
+ return match.group(1)
+
+ def parse_port(self, data):
+ match = re.search(r"PortDescr:\s+(.+)$", data, re.M)
+ if match:
+ return match.group(1)
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py
new file mode 100644
index 00000000..3c7e2f93
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py
@@ -0,0 +1,116 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos lldp_global fact class
+This file collects the configuration from the device for a given
+resource, parses it, and populates the facts tree based on the
+configuration.
+"""
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+from re import findall, M
+from copy import deepcopy
+
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import (
+ utils,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_global.lldp_global import (
+ Lldp_globalArgs,
+)
+
+
+class Lldp_globalFacts(object):
+ """ The vyos lldp_global fact class
+ """
+
+ def __init__(self, module, subspec="config", options="options"):
+ self._module = module
+ self.argument_spec = Lldp_globalArgs.argument_spec
+ spec = deepcopy(self.argument_spec)
+ if subspec:
+ if options:
+ facts_argument_spec = spec[subspec][options]
+ else:
+ facts_argument_spec = spec[subspec]
+ else:
+ facts_argument_spec = spec
+
+ self.generated_spec = utils.generate_dict(facts_argument_spec)
+
+ def populate_facts(self, connection, ansible_facts, data=None):
+ """ Populate the facts for lldp_global
+ :param connection: the device connection
+ :param ansible_facts: Facts dictionary
+ :param data: previously collected conf
+ :rtype: dictionary
+ :returns: facts
+ """
+ if not data:
+ data = connection.get_config()
+
+ objs = {}
+ lldp_output = findall(r"^set service lldp (\S+)", data, M)
+ if lldp_output:
+ for item in set(lldp_output):
+ lldp_regex = r" %s .+$" % item
+ cfg = findall(lldp_regex, data, M)
+ obj = self.render_config(cfg)
+ if obj:
+ objs.update(obj)
+ lldp_service = findall(r"^set service (lldp)?('lldp')", data, M)
+ if lldp_service or lldp_output:
+ lldp_obj = {}
+ lldp_obj["enable"] = True
+ objs.update(lldp_obj)
+
+ facts = {}
+ params = utils.validate_config(self.argument_spec, {"config": objs})
+ facts["lldp_global"] = utils.remove_empties(params["config"])
+
+ ansible_facts["ansible_network_resources"].update(facts)
+
+ return ansible_facts
+
+ def render_config(self, conf):
+ """
+ Render config as dictionary structure and delete keys
+ from spec for null values
+ :param spec: The facts tree, generated from the argspec
+ :param conf: The configuration
+ :rtype: dictionary
+ :returns: The generated config
+ """
+ protocol_conf = "\n".join(
+ filter(lambda x: ("legacy-protocols" in x), conf)
+ )
+ att_conf = "\n".join(
+ filter(lambda x: ("legacy-protocols" not in x), conf)
+ )
+ config = self.parse_attribs(["snmp", "address"], att_conf)
+ config["legacy_protocols"] = self.parse_protocols(protocol_conf)
+ return utils.remove_empties(config)
+
+ def parse_protocols(self, conf):
+ protocol_support = None
+ if conf:
+ protocols = findall(r"^.*legacy-protocols (.+)", conf, M)
+ if protocols:
+ protocol_support = []
+ for protocol in protocols:
+ protocol_support.append(protocol.strip("'"))
+ return protocol_support
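+
+ # e.g. lines "... legacy-protocols 'cdp'" and "... legacy-protocols 'fdp'"
+ # yield ["cdp", "fdp"] (protocol names illustrative)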
+
+ def parse_attribs(self, attribs, conf):
+ config = {}
+ for item in attribs:
+ value = utils.parse_conf_arg(conf, item)
+ if value:
+ config[item] = value.strip("'")
+ else:
+ config[item] = None
+ return utils.remove_empties(config)
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py
new file mode 100644
index 00000000..dcfbc6ee
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py
@@ -0,0 +1,155 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos lldp_interfaces fact class
+This file collects the configuration from the device for a given
+resource, parses it, and populates the facts tree based on the
+configuration.
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+
+from re import findall, search, M
+from copy import deepcopy
+
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import (
+ utils,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_interfaces.lldp_interfaces import (
+ Lldp_interfacesArgs,
+)
+
+
+class Lldp_interfacesFacts(object):
+ """ The vyos lldp_interfaces fact class
+ """
+
+ def __init__(self, module, subspec="config", options="options"):
+ self._module = module
+ self.argument_spec = Lldp_interfacesArgs.argument_spec
+ spec = deepcopy(self.argument_spec)
+ if subspec:
+ if options:
+ facts_argument_spec = spec[subspec][options]
+ else:
+ facts_argument_spec = spec[subspec]
+ else:
+ facts_argument_spec = spec
+
+ self.generated_spec = utils.generate_dict(facts_argument_spec)
+
+ def populate_facts(self, connection, ansible_facts, data=None):
+ """ Populate the facts for lldp_interfaces
+ :param connection: the device connection
+ :param ansible_facts: Facts dictionary
+ :param data: previously collected conf
+ :rtype: dictionary
+ :returns: facts
+ """
+ if not data:
+ data = connection.get_config()
+
+ objs = []
+ lldp_names = findall(r"^set service lldp interface (\S+)", data, M)
+ if lldp_names:
+ for lldp in set(lldp_names):
+ lldp_regex = r" %s .+$" % lldp
+ cfg = findall(lldp_regex, data, M)
+ obj = self.render_config(cfg)
+ obj["name"] = lldp.strip("'")
+ if obj:
+ objs.append(obj)
+ facts = {}
+ if objs:
+ facts["lldp_interfaces"] = objs
+
+ ansible_facts["ansible_network_resources"].update(facts)
+ return ansible_facts
+
+ def render_config(self, conf):
+ """
+ Render config as dictionary structure and delete keys
+ from spec for null values
+
+ :param spec: The facts tree, generated from the argspec
+ :param conf: The configuration
+ :rtype: dictionary
+ :returns: The generated config
+ """
+ config = {}
+ location = {}
+
+ civic_conf = "\n".join(filter(lambda x: ("civic-based" in x), conf))
+ elin_conf = "\n".join(filter(lambda x: ("elin" in x), conf))
+ coordinate_conf = "\n".join(
+ filter(lambda x: ("coordinate-based" in x), conf)
+ )
+ disable = "\n".join(filter(lambda x: ("disable" in x), conf))
+
+ coordinate_based_conf = self.parse_attribs(
+ ["altitude", "datum", "longitude", "latitude"], coordinate_conf
+ )
+ elin_based_conf = self.parse_lldp_elin_based(elin_conf)
+ civic_based_conf = self.parse_lldp_civic_based(civic_conf)
+ if disable:
+ config["enable"] = False
+ if coordinate_conf:
+ location["coordinate_based"] = coordinate_based_conf
+ config["location"] = location
+ elif civic_based_conf:
+ location["civic_based"] = civic_based_conf
+ config["location"] = location
+ elif elin_conf:
+ location["elin"] = elin_based_conf
+ config["location"] = location
+
+ return utils.remove_empties(config)
+
+ def parse_attribs(self, attribs, conf):
+ config = {}
+ for item in attribs:
+ value = utils.parse_conf_arg(conf, item)
+ if value:
+ value = value.strip("'")
+ if item == "altitude":
+ value = int(value)
+ config[item] = value
+ else:
+ config[item] = None
+ return utils.remove_empties(config)
+
+ def parse_lldp_civic_based(self, conf):
+ civic_based = None
+ if conf:
+ civic_info_list = []
+ civic_add_list = findall(r"^.*civic-based ca-type (.+)", conf, M)
+ if civic_add_list:
+ for civic_add in civic_add_list:
+ ca = civic_add.split(" ")
+ c_add = {}
+ c_add["ca_type"] = int(ca[0].strip("'"))
+ c_add["ca_value"] = ca[2].strip("'")
+ civic_info_list.append(c_add)
+
+ country_code = search(
+ r"^.*civic-based country-code (.+)", conf, M
+ )
+ civic_based = {}
+ civic_based["ca_info"] = civic_info_list
+ civic_based["country_code"] = country_code.group(1).strip("'")
+ return civic_based
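+
+ # Sketch for hypothetical lines "... civic-based ca-type 0 ca-value 'EN'"
+ # and "... civic-based country-code 'US'":
+ #   {"ca_info": [{"ca_type": 0, "ca_value": "EN"}], "country_code": "US"}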
+
+ def parse_lldp_elin_based(self, conf):
+ elin_based = None
+ if conf:
+ e_num = search(r"^.* elin (.+)", conf, M)
+ elin_based = e_num.group(1).strip("'")
+
+ return elin_based
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py
new file mode 100644
index 00000000..00049475
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py
@@ -0,0 +1,181 @@
+#
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The vyos static_routes fact class
+This file collects the configuration from the device for a given
+resource, parses it, and populates the facts tree based on the
+configuration.
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+from re import findall, search, M
+from copy import deepcopy
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import (
+ utils,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.static_routes.static_routes import (
+ Static_routesArgs,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.utils.utils import (
+ get_route_type,
+)
+
+
+class Static_routesFacts(object):
+ """ The vyos static_routes fact class
+ """
+
+ def __init__(self, module, subspec="config", options="options"):
+ self._module = module
+ self.argument_spec = Static_routesArgs.argument_spec
+ spec = deepcopy(self.argument_spec)
+ if subspec:
+ if options:
+ facts_argument_spec = spec[subspec][options]
+ else:
+ facts_argument_spec = spec[subspec]
+ else:
+ facts_argument_spec = spec
+
+ self.generated_spec = utils.generate_dict(facts_argument_spec)
+
+ def get_device_data(self, connection):
+ return connection.get_config()
+
+ def populate_facts(self, connection, ansible_facts, data=None):
+ """ Populate the facts for static_routes
+ :param connection: the device connection
+ :param ansible_facts: Facts dictionary
+ :param data: previously collected conf
+ :rtype: dictionary
+ :returns: facts
+ """
+ if not data:
+ data = self.get_device_data(connection)
+ # data holds the device configuration in "set" command format,
+ # as returned by connection.get_config()
+ objs = []
+ r_v4 = []
+ r_v6 = []
+ af = []
+ static_routes = findall(
+ r"set protocols static route(6)? (\S+)", data, M
+ )
+ if static_routes:
+ for route in set(static_routes):
+ route_regex = r" %s .+$" % route[1]
+ cfg = findall(route_regex, data, M)
+ sr = self.render_config(cfg)
+ sr["dest"] = route[1].strip("'")
+ afi = self.get_afi(sr["dest"])
+ if afi == "ipv4":
+ r_v4.append(sr)
+ else:
+ r_v6.append(sr)
+ if r_v4:
+ afi_v4 = {"afi": "ipv4", "routes": r_v4}
+ af.append(afi_v4)
+ if r_v6:
+ afi_v6 = {"afi": "ipv6", "routes": r_v6}
+ af.append(afi_v6)
+ config = {"address_families": af}
+ if config:
+ objs.append(config)
+
+ ansible_facts["ansible_network_resources"].pop("static_routes", None)
+ facts = {}
+ if objs:
+ facts["static_routes"] = []
+ params = utils.validate_config(
+ self.argument_spec, {"config": objs}
+ )
+ for cfg in params["config"]:
+ facts["static_routes"].append(utils.remove_empties(cfg))
+
+ ansible_facts["ansible_network_resources"].update(facts)
+ return ansible_facts
+
+ def render_config(self, conf):
+ """
+ Render config as dictionary structure and delete keys
+ from spec for null values
+
+ :param spec: The facts tree, generated from the argspec
+ :param conf: The configuration
+ :rtype: dictionary
+ :returns: The generated config
+ """
+ next_hops_conf = "\n".join(filter(lambda x: ("next-hop" in x), conf))
+ blackhole_conf = "\n".join(filter(lambda x: ("blackhole" in x), conf))
+ routes_dict = {
+ "blackhole_config": self.parse_blackhole(blackhole_conf),
+ "next_hops": self.parse_next_hop(next_hops_conf),
+ }
+ return routes_dict
+
+ def parse_blackhole(self, conf):
+ blackhole = None
+ if conf:
+ distance = search(r"^.*blackhole distance (.\S+)", conf, M)
+ bh = conf.find("blackhole")
+ if distance is not None:
+ blackhole = {}
+ value = distance.group(1).strip("'")
+ blackhole["distance"] = int(value)
+ # str.find() returns -1 when absent; test the index explicitly
+ elif bh >= 0:
+ blackhole = {}
+ blackhole["type"] = "blackhole"
+ return blackhole
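+
+ # Sketch: "... blackhole distance '2'" -> {"distance": 2}, while a bare
+ # "... blackhole" line -> {"type": "blackhole"}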
+
+ def get_afi(self, address):
+ route_type = get_route_type(address)
+ if route_type == "route":
+ return "ipv4"
+ elif route_type == "route6":
+ return "ipv6"
+
+ def parse_next_hop(self, conf):
+ nh_list = None
+ if conf:
+ nh_list = []
+ hop_list = findall(r"^.*next-hop (.+)", conf, M)
+ if hop_list:
+ for hop in hop_list:
+ distance = search(r"^.*distance (.\S+)", hop, M)
+ interface = search(r"^.*interface (.\S+)", hop, M)
+
+ dis = hop.find("disable")
+ hop_info = hop.split(" ")
+ nh_info = {
+ "forward_router_address": hop_info[0].strip("'")
+ }
+ if interface:
+ nh_info["interface"] = interface.group(1).strip("'")
+ if distance:
+ value = distance.group(1).strip("'")
+ nh_info["admin_distance"] = int(value)
+ elif dis >= 0:
+ nh_info["enabled"] = False
+ for element in nh_list:
+ if (
+ element["forward_router_address"]
+ == nh_info["forward_router_address"]
+ ):
+ if "interface" in nh_info.keys():
+ element["interface"] = nh_info["interface"]
+ if "admin_distance" in nh_info.keys():
+ element["admin_distance"] = nh_info[
+ "admin_distance"
+ ]
+ if "enabled" in nh_info.keys():
+ element["enabled"] = nh_info["enabled"]
+ nh_info = None
+ if nh_info is not None:
+ nh_list.append(nh_info)
+ return nh_list
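+
+ # Sketch: repeated "next-hop <addr> ..." lines for one address are merged
+ # into a single entry, e.g. "next-hop '192.0.2.1' distance '2'" plus
+ # "next-hop '192.0.2.1' 'disable'" (hypothetical) produce one dict with
+ # forward_router_address, admin_distance=2 and enabled=False.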
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py
new file mode 100644
index 00000000..402adfc9
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py
@@ -0,0 +1,231 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# utils
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+from ansible.module_utils.six import iteritems
+from ansible_collections.ansible.netcommon.plugins.module_utils.compat import (
+ ipaddress,
+)
+
+
+def search_obj_in_list(name, lst, key="name"):
+ for item in lst:
+ if item[key] == name:
+ return item
+ return None
+
+
+def get_interface_type(interface):
+ """Gets the type of interface
+ """
+ if interface.startswith("eth"):
+ return "ethernet"
+ elif interface.startswith("bond"):
+ return "bonding"
+ elif interface.startswith("vti"):
+ return "vti"
+ elif interface.startswith("lo"):
+ return "loopback"
+
+
+def dict_delete(base, comparable):
+ """
+ This function generates a dict containing key, value pairs for keys
+ that are present in the `base` dict but not present in the `comparable`
+ dict.
+
+ :param base: dict object to base the diff on
+ :param comparable: dict object to compare against base
+ :returns: new dict object with key, value pairs that need to be deleted.
+
+ """
+ to_delete = dict()
+
+ for key in base:
+ if isinstance(base[key], dict):
+ sub_diff = dict_delete(base[key], comparable.get(key, {}))
+ if sub_diff:
+ to_delete[key] = sub_diff
+ else:
+ if key not in comparable:
+ to_delete[key] = base[key]
+
+ return to_delete
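+
+# e.g. dict_delete({"a": 1, "b": {"c": 2, "d": 3}}, {"b": {"c": 2}})
+#      returns {"a": 1, "b": {"d": 3}}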
+
+
+def diff_list_of_dicts(want, have):
+ diff = []
+
+ set_w = set(tuple(d.items()) for d in want)
+ set_h = set(tuple(d.items()) for d in have)
+ difference = set_w.difference(set_h)
+
+ for element in difference:
+ diff.append(dict((x, y) for x, y in element))
+
+ return diff
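+
+# e.g. diff_list_of_dicts([{"name": "eth0"}, {"name": "eth1"}],
+#                         [{"name": "eth0"}]) returns [{"name": "eth1"}]
+# (result order is not guaranteed, since sets are used internally)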
+
+
+def get_lst_diff_for_dicts(want, have, lst):
+ """
+ This function generates a list containing values
+ that are in the want dict's list but not in the have dict's list.
+ :param want: the want dict
+ :param have: the have dict
+ :param lst: key of the list to diff
+ :return: new list object with values which are only in want.
+ """
+ if not have:
+ diff = want.get(lst) or []
+
+ else:
+ want_elements = want.get(lst) or {}
+ have_elements = have.get(lst) or {}
+ diff = list_diff_want_only(want_elements, have_elements)
+ return diff
+
+
+def get_lst_same_for_dicts(want, have, lst):
+ """
+ This function generates a list containing values
+ that are common to the want dict's list and the have dict's list.
+ :param want: the want dict
+ :param have: the have dict
+ :param lst: key of the list to compare
+ :return: new list object with values which are common in want and have.
+ """
+ diff = None
+ if want and have:
+ want_list = want.get(lst) or {}
+ have_list = have.get(lst) or {}
+ diff = [i for i in have_list if i in want_list]
+ return diff
+
+
+def list_diff_have_only(want_list, have_list):
+ """
+ This function generates the list containing values
+ that are only in the have list.
+ :param want_list:
+ :param have_list:
+ :return: new list with values which are only in have list
+ """
+ if have_list and not want_list:
+ diff = have_list
+ elif not have_list:
+ diff = None
+ else:
+ diff = [
+ i
+ for i in have_list + want_list
+ if i in have_list and i not in want_list
+ ]
+ return diff
+
+
+def list_diff_want_only(want_list, have_list):
+ """
+ This function generates the list containing values
+ that are only in the want list.
+ :param want_list:
+ :param have_list:
+ :return: new list with values which are only in want list
+ """
+ if have_list and not want_list:
+ diff = None
+ elif not have_list:
+ diff = want_list
+ else:
+ diff = [
+ i
+ for i in have_list + want_list
+ if i in want_list and i not in have_list
+ ]
+ return diff
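+
+# e.g. list_diff_want_only([1, 2, 3], [2]) returns [1, 3]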
+
+
+def search_dict_tv_in_list(d_val1, d_val2, lst, key1, key2):
+ """
+ This function returns the dict object if it exists in the list.
+ :param d_val1:
+ :param d_val2:
+ :param lst:
+ :param key1:
+ :param key2:
+ :return:
+ """
+ obj = next(
+ (
+ item
+ for item in lst
+ if item[key1] == d_val1 and item[key2] == d_val2
+ ),
+ None,
+ )
+ if obj:
+ return obj
+ else:
+ return None
+
+
+def key_value_in_dict(have_key, have_value, want_dict):
+ """
+ This function checks whether the key and values exist in dict
+ :param have_key:
+ :param have_value:
+ :param want_dict:
+ :return:
+ """
+ for key, value in iteritems(want_dict):
+ if key == have_key and value == have_value:
+ return True
+ return False
+
+
+def is_dict_element_present(dict, key):
+ """
+ This function checks whether the key is present in dict.
+ :param dict:
+ :param key:
+ :return:
+ """
+ for item in dict:
+ if item == key:
+ return True
+ return False
+
+
+def get_ip_address_version(address):
+ """
+ This function returns the version of an IP address.
+ :param address: IP address
+ :return:
+ """
+ try:
+ address = unicode(address)
+ except NameError:
+ address = str(address)
+ version = ipaddress.ip_address(address.split("/")[0]).version
+ return version
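+
+# e.g. get_ip_address_version("192.0.2.1/24") returns 4 and
+#      get_ip_address_version("2001:db8::1") returns 6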
+
+
+def get_route_type(address):
+ """
+ This function returns the route type based on the IP address version.
+ :param address:
+ :return:
+ """
+ version = get_ip_address_version(address)
+ if version == 6:
+ return "route6"
+ elif version == 4:
+ return "route"
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py
new file mode 100644
index 00000000..908395a6
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py
@@ -0,0 +1,124 @@
+# This code is part of Ansible, but is an independent component.
+# This particular file snippet, and this file snippet only, is BSD licensed.
+# Modules you write using this snippet, which is embedded dynamically by Ansible
+# still belong to the author of the module, and may assign their own license
+# to the complete work.
+#
+# (c) 2016 Red Hat Inc.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+import json
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.basic import env_fallback
+from ansible.module_utils.connection import Connection, ConnectionError
+
+_DEVICE_CONFIGS = {}
+
+vyos_provider_spec = {
+ "host": dict(),
+ "port": dict(type="int"),
+ "username": dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])),
+ "password": dict(
+ fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]), no_log=True
+ ),
+ "ssh_keyfile": dict(
+ fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path"
+ ),
+ "timeout": dict(type="int"),
+}
+vyos_argument_spec = {
+ "provider": dict(
+ type="dict", options=vyos_provider_spec, removed_in_version=2.14
+ ),
+}
+
+
+def get_provider_argspec():
+ return vyos_provider_spec
+
+
+def get_connection(module):
+ if hasattr(module, "_vyos_connection"):
+ return module._vyos_connection
+
+ capabilities = get_capabilities(module)
+ network_api = capabilities.get("network_api")
+ if network_api == "cliconf":
+ module._vyos_connection = Connection(module._socket_path)
+ else:
+ module.fail_json(msg="Invalid connection type %s" % network_api)
+
+ return module._vyos_connection
+
+
+def get_capabilities(module):
+ if hasattr(module, "_vyos_capabilities"):
+ return module._vyos_capabilities
+
+ try:
+ capabilities = Connection(module._socket_path).get_capabilities()
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+
+ module._vyos_capabilities = json.loads(capabilities)
+ return module._vyos_capabilities
+
+
+def get_config(module, flags=None, format=None):
+ flags = [] if flags is None else flags
+ global _DEVICE_CONFIGS
+
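+ # NOTE: once populated, the cached text is returned unchanged; the flags
+ # and format arguments only take effect on the first retrieval.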
+ if _DEVICE_CONFIGS != {}:
+ return _DEVICE_CONFIGS
+ else:
+ connection = get_connection(module)
+ try:
+ out = connection.get_config(flags=flags, format=format)
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+ cfg = to_text(out, errors="surrogate_then_replace").strip()
+ _DEVICE_CONFIGS = cfg
+ return cfg
+
+
+def run_commands(module, commands, check_rc=True):
+ connection = get_connection(module)
+ try:
+ response = connection.run_commands(
+ commands=commands, check_rc=check_rc
+ )
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+ return response
+
+
+def load_config(module, commands, commit=False, comment=None):
+ connection = get_connection(module)
+
+ try:
+ response = connection.edit_config(
+ candidate=commands, commit=commit, comment=comment
+ )
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+
+ return response.get("diff")
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py
new file mode 100644
index 00000000..18538491
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py
@@ -0,0 +1,223 @@
+#!/usr/bin/python
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: vyos_command
+author: Nathaniel Case (@Qalthos)
+short_description: Run one or more commands on VyOS devices
+description:
+- The command module allows running one or more commands on remote devices running
+ VyOS. This module can also be introspected to validate key parameters before returning
+ successfully. If the conditional statements are not met in the wait period, the
+ task fails.
+- Certain C(show) commands in VyOS produce many lines of output and use a custom pager
+ that can cause this module to hang. If the value of the environment variable C(ANSIBLE_VYOS_TERMINAL_LENGTH)
+ is not set, the default number of 10000 is used.
+extends_documentation_fragment:
+- vyos.vyos.vyos
+options:
+ commands:
+ description:
+ - The ordered set of commands to execute on the remote device running VyOS. The
+ output from the command execution is returned to the playbook. If the I(wait_for)
+ argument is provided, the module does not return until the condition is satisfied
+ or the number of retries has been exceeded.
+ required: true
+ wait_for:
+ description:
+ - Specifies what to evaluate from the output of the command and what conditionals
+ to apply. This argument will cause the task to wait for a particular conditional
+ to be true before moving forward. If the conditional is not true by the configured
+ I(retries), the task fails. See examples.
+ aliases:
+ - waitfor
+ match:
+ description:
+ - The I(match) argument is used in conjunction with the I(wait_for) argument to
+ specify the match policy. Valid values are C(all) or C(any). If the value is
+ set to C(all) then all conditionals in the wait_for must be satisfied. If the
+ value is set to C(any) then only one of the values must be satisfied.
+ default: all
+ choices:
+ - any
+ - all
+ retries:
+ description:
+ - Specifies the number of retries a command should be tried before it is considered
+ failed. The command is run on the target device every retry and evaluated against
+ the I(wait_for) conditionals.
+ default: 10
+ interval:
+ description:
+ - Configures the interval in seconds to wait between I(retries) of the command.
+ If the command does not pass the specified conditions, the interval indicates
+ how long to wait before trying the command again.
+ default: 1
+notes:
+- Tested against VyOS 1.1.8 (helium).
+- Running C(show system boot-messages all) will cause the module to hang since VyOS
+ is using a custom pager setting to display the output of that command.
+- If a command sent to the device requires answering a prompt, it is possible to pass
+ a dict containing I(command), I(answer) and I(prompt). See examples.
+- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html).
+"""
+
+EXAMPLES = """
+tasks:
+ - name: show configuration on ethernet devices eth0 and eth1
+ vyos_command:
+ commands:
+ - show interfaces ethernet {{ item }}
+ with_items:
+ - eth0
+ - eth1
+
+ - name: run multiple commands and check if version output contains specific version string
+ vyos_command:
+ commands:
+ - show version
+ - show hardware cpu
+ wait_for:
+ - "result[0] contains 'VyOS 1.1.7'"
+
+ - name: run command that requires answering a prompt
+ vyos_command:
+ commands:
+ - command: 'rollback 1'
+ prompt: 'Proceed with reboot? [confirm][y]'
+ answer: y
+"""
+
+RETURN = """
+stdout:
+ description: The set of responses from the commands
+ returned: always apart from low level errors (such as action plugin)
+ type: list
+ sample: ['...', '...']
+stdout_lines:
+ description: The value of stdout split into a list
+ returned: always
+ type: list
+ sample: [['...', '...'], ['...'], ['...']]
+failed_conditions:
+ description: The list of conditionals that have failed
+ returned: failed
+ type: list
+ sample: ['...', '...']
+warnings:
+ description: The list of warnings (if any) generated by module based on arguments
+ returned: always
+ type: list
+ sample: ['...', '...']
+"""
+import time
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.basic import AnsibleModule
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import (
+ Conditional,
+)
+from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
+ transform_commands,
+ to_lines,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import (
+ run_commands,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import (
+ vyos_argument_spec,
+)
+
+
+def parse_commands(module, warnings):
+ commands = transform_commands(module)
+
+ if module.check_mode:
+ for item in list(commands):
+ if not item["command"].startswith("show"):
+ warnings.append(
+ "Only show commands are supported when using check mode, not "
+ "executing %s" % item["command"]
+ )
+ commands.remove(item)
+
+ return commands
+
+
+def main():
+ spec = dict(
+ commands=dict(type="list", required=True),
+ wait_for=dict(type="list", aliases=["waitfor"]),
+ match=dict(default="all", choices=["all", "any"]),
+ retries=dict(default=10, type="int"),
+ interval=dict(default=1, type="int"),
+ )
+
+ spec.update(vyos_argument_spec)
+
+ module = AnsibleModule(argument_spec=spec, supports_check_mode=True)
+
+ warnings = list()
+ result = {"changed": False, "warnings": warnings}
+ commands = parse_commands(module, warnings)
+ wait_for = module.params["wait_for"] or list()
+
+ try:
+ conditionals = [Conditional(c) for c in wait_for]
+ except AttributeError as exc:
+ module.fail_json(msg=to_text(exc))
+
+ retries = module.params["retries"]
+ interval = module.params["interval"]
+ match = module.params["match"]
+
+ for _ in range(retries):
+ responses = run_commands(module, commands)
+
+ for item in list(conditionals):
+ if item(responses):
+ if match == "any":
+ conditionals = list()
+ break
+ conditionals.remove(item)
+
+ if not conditionals:
+ break
+
+ time.sleep(interval)
+
+ if conditionals:
+ failed_conditions = [item.raw for item in conditionals]
+ msg = "One or more conditional statements have not been satisfied"
+ module.fail_json(msg=msg, failed_conditions=failed_conditions)
+
+ result.update(
+ {"stdout": responses, "stdout_lines": list(to_lines(responses)),}
+ )
+
+ module.exit_json(**result)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py
new file mode 100644
index 00000000..b899045a
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py
@@ -0,0 +1,354 @@
+#!/usr/bin/python
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: vyos_config
+author: Nathaniel Case (@Qalthos)
+short_description: Manage VyOS configuration on remote device
+description:
+- This module provides configuration file management of VyOS devices. It provides
+ arguments for managing both the configuration file and state of the active configuration.
+ All configuration statements are based on `set` and `delete` commands in the device
+ configuration.
+extends_documentation_fragment:
+- vyos.vyos.vyos
+notes:
+- Tested against VyOS 1.1.8 (helium).
+- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html).
+options:
+ lines:
+ description:
+ - The ordered set of configuration lines to be managed and compared with the existing
+ configuration on the remote device.
+ src:
+ description:
+ - The C(src) argument specifies the path to the source config file to load. The
+ source config file can either be in bracket format or set format. The source
+ file can include Jinja2 template variables.
+ match:
+ description:
+ - The C(match) argument controls the method used to match against the current
+ active configuration. By default, the desired config is matched against the
+ active config and the deltas are loaded. If the C(match) argument is set to
+ C(none) the active configuration is ignored and the configuration is always
+ loaded.
+ default: line
+ choices:
+ - line
+ - none
+ backup:
+ description:
+ - The C(backup) argument will backup the current devices active configuration
+ to the Ansible control host prior to making any changes. If the C(backup_options)
+ value is not given, the backup file will be located in the backup folder in
+ the playbook root directory or role root directory, if playbook is part of an
+ ansible role. If the directory does not exist, it is created.
+ type: bool
+ default: 'no'
+ comment:
+ description:
+ - Allows a commit description to be specified to be included when the configuration
+ is committed. If the configuration is not changed or committed, this argument
+ is ignored.
+ default: configured by vyos_config
+ config:
+ description:
+ - The C(config) argument specifies the base configuration to use to compare against
+ the desired configuration. If this value is not specified, the module will
+ automatically retrieve the current active configuration from the remote device.
+ save:
+ description:
+ - The C(save) argument controls whether or not changes made to the active configuration
+ are saved to disk. This is independent of committing the config. When set
+ to True, the active configuration is saved.
+ type: bool
+ default: 'no'
+ backup_options:
+ description:
+ - This is a dict object containing configurable options related to backup file
+ path. The value of this option is read only when C(backup) is set to I(yes),
+ if C(backup) is set to I(no) this option will be silently ignored.
+ suboptions:
+ filename:
+ description:
+ - The filename to be used to store the backup configuration. If the filename
+ is not given it will be generated based on the hostname, current time and
+ date in format defined by <hostname>_config.<current-date>@<current-time>
+ dir_path:
+ description:
+ - This option provides the path ending with directory name in which the backup
+ configuration file will be stored. If the directory does not exist it will
+ be first created and the filename is either the value of C(filename) or
+ default filename as described in C(filename) options description. If the
+ path value is not given in that case a I(backup) directory will be created
+ in the current working directory and backup configuration will be copied
+ in C(filename) within I(backup) directory.
+ type: path
+ type: dict
+"""
+
+EXAMPLES = """
+- name: configure the remote device
+ vyos_config:
+ lines:
+ - set system host-name {{ inventory_hostname }}
+ - set service lldp
+ - delete service dhcp-server
+
+- name: backup and load from file
+ vyos_config:
+ src: vyos.cfg
+ backup: yes
+
+- name: render a Jinja2 template onto the VyOS router
+ vyos_config:
+ src: vyos_template.j2
+
+- name: for idempotency, use full-form commands
+ vyos_config:
+ lines:
+ # - set int eth eth2 description 'OUTSIDE'
+ - set interface ethernet eth2 description 'OUTSIDE'
+
+- name: configurable backup path
+ vyos_config:
+ backup: yes
+ backup_options:
+ filename: backup.cfg
+ dir_path: /home/user
+"""
+
+RETURN = """
+commands:
+ description: The list of configuration commands sent to the device
+ returned: always
+ type: list
+ sample: ['...', '...']
+filtered:
+ description: The list of configuration commands removed to avoid a load failure
+ returned: always
+ type: list
+ sample: ['...', '...']
+backup_path:
+ description: The full path to the backup file
+ returned: when backup is yes
+ type: str
+ sample: /playbooks/ansible/backup/vyos_config.2016-07-16@22:28:34
+filename:
+ description: The name of the backup file
+ returned: when backup is yes and filename is not specified in backup options
+ type: str
+ sample: vyos_config.2016-07-16@22:28:34
+shortname:
+ description: The full path to the backup file excluding the timestamp
+ returned: when backup is yes and filename is not specified in backup options
+ type: str
+ sample: /playbooks/ansible/backup/vyos_config
+date:
+ description: The date extracted from the backup file name
+ returned: when backup is yes
+ type: str
+ sample: "2016-07-16"
+time:
+ description: The time extracted from the backup file name
+ returned: when backup is yes
+ type: str
+ sample: "22:28:34"
+"""
+import re
+
+from ansible.module_utils._text import to_text
+from ansible.module_utils.basic import AnsibleModule
+from ansible.module_utils.connection import ConnectionError
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import (
+ load_config,
+ get_config,
+ run_commands,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import (
+ vyos_argument_spec,
+ get_connection,
+)
+
+
+DEFAULT_COMMENT = "configured by vyos_config"
+
+CONFIG_FILTERS = [
+ re.compile(r"set system login user \S+ authentication encrypted-password")
+]
+
+
+def get_candidate(module):
+ contents = module.params["src"] or module.params["lines"]
+
+ if module.params["src"]:
+ contents = format_commands(contents.splitlines())
+
+ contents = "\n".join(contents)
+ return contents
+
+
+def format_commands(commands):
+ """
+    This function formats the input commands: it strips leading whitespace from
+    command lines beginning with 'set' or 'delete' and skips empty lines.
+ :param commands:
+ :return: list of commands
+ """
+ return [
+ line.strip() if line.split()[0] in ("set", "delete") else line
+ for line in commands
+ if len(line.strip()) > 0
+ ]
+
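+# A minimal illustration of format_commands with hypothetical input (not taken
+# from the module's own data): leading whitespace is stripped from 'set' and
+# 'delete' lines, blank lines are dropped, and other lines pass through as-is:
+#
+#   format_commands(["  set system host-name r1", "", "show interfaces"])
+#   # -> ["set system host-name r1", "show interfaces"]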
+
+def diff_config(commands, config):
+ config = [str(c).replace("'", "") for c in config.splitlines()]
+
+ updates = list()
+ visited = set()
+
+ for line in commands:
+ item = str(line).replace("'", "")
+
+ if not item.startswith("set") and not item.startswith("delete"):
+ raise ValueError("line must start with either `set` or `delete`")
+
+ elif item.startswith("set") and item not in config:
+ updates.append(line)
+
+ elif item.startswith("delete"):
+ if not config:
+ updates.append(line)
+ else:
+ item = re.sub(r"delete", "set", item)
+ for entry in config:
+ if entry.startswith(item) and line not in visited:
+ updates.append(line)
+ visited.add(line)
+
+ return list(updates)
+
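+# A minimal sketch of diff_config semantics with hypothetical data: 'set'
+# lines already present in the running config are suppressed, while 'delete'
+# lines are kept only when a matching 'set' entry exists in the config:
+#
+#   running = "set system host-name r1\nset service lldp"
+#   diff_config(["set system host-name r1", "delete service lldp"], running)
+#   # -> ["delete service lldp"]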
+
+def sanitize_config(config, result):
+ result["filtered"] = list()
+ index_to_filter = list()
+ for regex in CONFIG_FILTERS:
+ for index, line in enumerate(list(config)):
+ if regex.search(line):
+ result["filtered"].append(line)
+ index_to_filter.append(index)
+ # Delete all filtered configs
+ for filter_index in sorted(index_to_filter, reverse=True):
+ del config[filter_index]
+
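+# Example of sanitize_config with hypothetical data: commands matching
+# CONFIG_FILTERS are moved out of the candidate list and into
+# result['filtered'] so they are never replayed to the device:
+#
+#   result = {}
+#   cmds = ["set system login user admin authentication encrypted-password 'x'",
+#           "set service lldp"]
+#   sanitize_config(cmds, result)
+#   # cmds -> ["set service lldp"]
+#   # result['filtered'] -> [the encrypted-password command]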
+
+def run(module, result):
+ # get the current active config from the node or passed in via
+ # the config param
+ config = module.params["config"] or get_config(module)
+
+ # create the candidate config object from the arguments
+ candidate = get_candidate(module)
+
+ # create loadable config that includes only the configuration updates
+ connection = get_connection(module)
+ try:
+ response = connection.get_diff(
+ candidate=candidate,
+ running=config,
+ diff_match=module.params["match"],
+ )
+ except ConnectionError as exc:
+ module.fail_json(msg=to_text(exc, errors="surrogate_then_replace"))
+
+ commands = response.get("config_diff")
+ sanitize_config(commands, result)
+
+ result["commands"] = commands
+
+ commit = not module.check_mode
+ comment = module.params["comment"]
+
+ diff = None
+ if commands:
+ diff = load_config(module, commands, commit=commit, comment=comment)
+
+ if result.get("filtered"):
+ result["warnings"].append(
+ "Some configuration commands were "
+ "removed, please see the filtered key"
+ )
+
+ result["changed"] = True
+
+ if module._diff:
+ result["diff"] = {"prepared": diff}
+
+
+def main():
+ backup_spec = dict(filename=dict(), dir_path=dict(type="path"))
+ argument_spec = dict(
+ src=dict(type="path"),
+ lines=dict(type="list"),
+ match=dict(default="line", choices=["line", "none"]),
+ comment=dict(default=DEFAULT_COMMENT),
+ config=dict(),
+ backup=dict(type="bool", default=False),
+ backup_options=dict(type="dict", options=backup_spec),
+ save=dict(type="bool", default=False),
+ )
+
+ argument_spec.update(vyos_argument_spec)
+
+ mutually_exclusive = [("lines", "src")]
+
+ module = AnsibleModule(
+ argument_spec=argument_spec,
+ mutually_exclusive=mutually_exclusive,
+ supports_check_mode=True,
+ )
+
+ warnings = list()
+
+ result = dict(changed=False, warnings=warnings)
+
+ if module.params["backup"]:
+ result["__backup__"] = get_config(module=module)
+
+ if any((module.params["src"], module.params["lines"])):
+ run(module, result)
+
+ if module.params["save"]:
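+        # 'compare saved' prints just '[edit]' when the active and saved
+        # configurations already match, so any other output indicates a
+        # pending diff that is worth saving.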
+ diff = run_commands(module, commands=["configure", "compare saved"])[1]
+ if diff != "[edit]":
+ run_commands(module, commands=["save"])
+ result["changed"] = True
+ run_commands(module, commands=["exit"])
+
+ module.exit_json(**result)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py
new file mode 100644
index 00000000..19fb727f
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py
@@ -0,0 +1,174 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+"""
+The module file for vyos_facts
+"""
+
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": [u"preview"],
+ "supported_by": "network",
+}
+
+
+DOCUMENTATION = """module: vyos_facts
+short_description: Get facts about vyos devices.
+description:
+- Collects facts from network devices running the vyos operating system. This module
+ places the facts gathered in the fact tree keyed by the respective resource name. The
+  facts module always collects a base set of facts from the device; collection of
+  additional facts can be enabled or disabled.
+author:
+- Nathaniel Case (@qalthos)
+- Nilashish Chakraborty (@Nilashishc)
+- Rohit Thakur (@rohitthakur2590)
+extends_documentation_fragment:
+- vyos.vyos.vyos
+notes:
+- Tested against VyOS 1.1.8 (helium).
+- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html).
+options:
+ gather_subset:
+ description:
+    - When supplied, this argument restricts the facts collected to a given subset.
+      Possible values for this argument include all, default, config, and neighbors.
+      You can specify a list of values to include a larger subset. Values can also
+      be used with an initial C(!) to specify that a specific subset should not be collected.
+ required: false
+ default: '!config'
+ gather_network_resources:
+ description:
+    - When supplied, this argument restricts the facts collected to a given subset.
+      Possible values for this argument include all and individual resources such
+      as interfaces. You can specify a list of values to include a larger subset.
+      Values can also be used with an initial C(!) to specify that a specific subset
+      should not be collected. Valid subsets are 'all', 'interfaces', 'l3_interfaces',
+      'lag_interfaces', 'lldp_global', 'lldp_interfaces', 'static_routes', 'firewall_rules'.
+ required: false
+"""
+
+EXAMPLES = """
+# Gather all facts
+- vyos_facts:
+ gather_subset: all
+ gather_network_resources: all
+
+# collect only the config and default facts
+- vyos_facts:
+ gather_subset: config
+
+# collect everything except the config
+- vyos_facts:
+ gather_subset: "!config"
+
+# Collect only the interfaces facts
+- vyos_facts:
+ gather_subset:
+ - '!all'
+ - '!min'
+ gather_network_resources:
+ - interfaces
+
+# Do not collect interfaces facts
+- vyos_facts:
+ gather_network_resources:
+ - "!interfaces"
+
+# Collect interfaces and minimal default facts
+- vyos_facts:
+ gather_subset: min
+ gather_network_resources: interfaces
+"""
+
+RETURN = """
+ansible_net_config:
+ description: The running-config from the device
+ returned: when config is configured
+ type: str
+ansible_net_commits:
+ description: The set of available configuration revisions
+ returned: when present
+ type: list
+ansible_net_hostname:
+ description: The configured system hostname
+ returned: always
+ type: str
+ansible_net_model:
+ description: The device model string
+ returned: always
+ type: str
+ansible_net_serialnum:
+ description: The serial number of the device
+ returned: always
+ type: str
+ansible_net_version:
+ description: The version of the software running
+ returned: always
+ type: str
+ansible_net_neighbors:
+ description: The set of LLDP neighbors
+ returned: when interface is configured
+ type: list
+ansible_net_gather_subset:
+ description: The list of subsets gathered by the module
+ returned: always
+ type: list
+ansible_net_api:
+ description: The name of the transport
+ returned: always
+ type: str
+ansible_net_python_version:
+  description: The Python version the Ansible controller is using
+ returned: always
+ type: str
+ansible_net_gather_network_resources:
+ description: The list of fact resource subsets collected from the device
+ returned: always
+ type: list
+"""
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.facts.facts import (
+ FactsArgs,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.facts import (
+ Facts,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import (
+ vyos_argument_spec,
+)
+
+
+def main():
+ """
+ Main entry point for module execution
+
+ :returns: ansible_facts
+ """
+ argument_spec = FactsArgs.argument_spec
+ argument_spec.update(vyos_argument_spec)
+
+ module = AnsibleModule(
+ argument_spec=argument_spec, supports_check_mode=True
+ )
+
+ warnings = []
+ if module.params["gather_subset"] == "!config":
+ warnings.append(
+ "default value for `gather_subset` will be changed to `min` from `!config` v2.11 onwards"
+ )
+
+ result = Facts(module).get_facts()
+
+ ansible_facts, additional_warnings = result
+ warnings.extend(additional_warnings)
+
+ module.exit_json(ansible_facts=ansible_facts, warnings=warnings)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py
new file mode 100644
index 00000000..8fe572b0
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py
@@ -0,0 +1,513 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+# Copyright 2019 Red Hat
+# GNU General Public License v3.0+
+# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#############################################
+# WARNING #
+#############################################
+#
+# This file is auto generated by the resource
+# module builder playbook.
+#
+# Do not edit this file manually.
+#
+# Changes to this file will be over written
+# by the resource module builder.
+#
+# Changes should be made in the model used to
+# generate this file or in the resource module
+# builder template.
+#
+#############################################
+
+"""
+The module file for vyos_lldp_interfaces
+"""
+
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+ANSIBLE_METADATA = {
+ "metadata_version": "1.1",
+ "status": ["preview"],
+ "supported_by": "network",
+}
+
+DOCUMENTATION = """module: vyos_lldp_interfaces
+short_description: Manages attributes of LLDP interfaces on VyOS devices.
+description: This module manages attributes of LLDP interfaces on VyOS network devices.
+notes:
+- Tested against VyOS 1.1.8 (helium).
+- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html).
+author:
+- Rohit Thakur (@rohitthakur2590)
+options:
+ config:
+ description: A list of lldp interfaces configurations.
+ type: list
+ suboptions:
+ name:
+ description:
+ - Name of the lldp interface.
+ type: str
+ required: true
+ enable:
+ description:
+        - Enable LLDP on the interface; set to C(false) to disable it.
+ type: bool
+ default: true
+ location:
+ description:
+ - LLDP-MED location data.
+ type: dict
+ suboptions:
+ civic_based:
+ description:
+ - Civic-based location data.
+ type: dict
+ suboptions:
+ ca_info:
+ description: LLDP-MED address info
+ type: list
+ suboptions:
+ ca_type:
+ description: LLDP-MED Civic Address type.
+ type: int
+ required: true
+ ca_value:
+ description: LLDP-MED Civic Address value.
+ type: str
+ required: true
+ country_code:
+ description: Country Code
+ type: str
+ required: true
+ coordinate_based:
+ description:
+ - Coordinate-based location.
+ type: dict
+ suboptions:
+ altitude:
+ description: Altitude in meters.
+ type: int
+ datum:
+ description: Coordinate datum type.
+ type: str
+ choices:
+ - WGS84
+ - NAD83
+ - MLLW
+ latitude:
+ description: Latitude.
+ type: str
+ required: true
+ longitude:
+ description: Longitude.
+ type: str
+ required: true
+ elin:
+            description: Emergency Call Service ELIN number (between 10 and 25 digits).
+ type: str
+ state:
+ description:
+ - The state of the configuration after module completion.
+ type: str
+ choices:
+ - merged
+ - replaced
+ - overridden
+ - deleted
+ default: merged
+"""
+EXAMPLES = """
+# Using merged
+#
+# Before state:
+# -------------
+#
+# vyos@vyos:~$ show configuration commands | grep lldp
+#
+- name: Merge provided configuration with device configuration
+ vyos_lldp_interfaces:
+ config:
+ - name: 'eth1'
+ location:
+ civic_based:
+ country_code: 'US'
+ ca_info:
+ - ca_type: 0
+ ca_value: 'ENGLISH'
+
+ - name: 'eth2'
+ location:
+ coordinate_based:
+ altitude: 2200
+ datum: 'WGS84'
+ longitude: '222.267255W'
+ latitude: '33.524449N'
+ state: merged
+#
+#
+# -------------------------
+# Module Execution Result
+# -------------------------
+#
+# "before": []
+#
+# "commands": [
+# "set service lldp interface eth1 location civic-based country-code 'US'",
+# "set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH'",
+# "set service lldp interface eth1",
+# "set service lldp interface eth2 location coordinate-based latitude '33.524449N'",
+# "set service lldp interface eth2 location coordinate-based altitude '2200'",
+# "set service lldp interface eth2 location coordinate-based datum 'WGS84'",
+# "set service lldp interface eth2 location coordinate-based longitude '222.267255W'",
+# "set service lldp interface eth2"
+# ]
+#
+# "after": [
+# {
+# "location": {
+# "coordinate_based": {
+# "altitude": 2200,
+# "datum": "WGS84",
+# "latitude": "33.524449N",
+# "longitude": "222.267255W"
+# }
+# },
+# "name": "eth2"
+# },
+# {
+# "location": {
+# "civic_based": {
+# "ca_info": [
+# {
+# "ca_type": 0,
+# "ca_value": "ENGLISH"
+# }
+# ],
+# "country_code": "US"
+# }
+# },
+# "name": "eth1"
+# }
+# ],
+#
+# After state:
+# -------------
+#
+# vyos@vyos:~$ show configuration commands | grep lldp
+# set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH'
+# set service lldp interface eth1 location civic-based country-code 'US'
+# set service lldp interface eth2 location coordinate-based altitude '2200'
+# set service lldp interface eth2 location coordinate-based datum 'WGS84'
+# set service lldp interface eth2 location coordinate-based latitude '33.524449N'
+# set service lldp interface eth2 location coordinate-based longitude '222.267255W'
+
+
+# Using replaced
+#
+# Before state:
+# -------------
+#
+# vyos@vyos:~$ show configuration commands | grep lldp
+# set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH'
+# set service lldp interface eth1 location civic-based country-code 'US'
+# set service lldp interface eth2 location coordinate-based altitude '2200'
+# set service lldp interface eth2 location coordinate-based datum 'WGS84'
+# set service lldp interface eth2 location coordinate-based latitude '33.524449N'
+# set service lldp interface eth2 location coordinate-based longitude '222.267255W'
+#
+- name: Replace device configurations of listed LLDP interfaces with provided configurations
+ vyos_lldp_interfaces:
+ config:
+ - name: 'eth2'
+ location:
+ civic_based:
+ country_code: 'US'
+ ca_info:
+ - ca_type: 0
+ ca_value: 'ENGLISH'
+
+ - name: 'eth1'
+ location:
+ coordinate_based:
+ altitude: 2200
+ datum: 'WGS84'
+ longitude: '222.267255W'
+ latitude: '33.524449N'
+ state: replaced
+#
+#
+# -------------------------
+# Module Execution Result
+# -------------------------
+#
+# "before": [
+# {
+# "location": {
+# "coordinate_based": {
+# "altitude": 2200,
+# "datum": "WGS84",
+# "latitude": "33.524449N",
+# "longitude": "222.267255W"
+# }
+# },
+# "name": "eth2"
+# },
+# {
+# "location": {
+# "civic_based": {
+# "ca_info": [
+# {
+# "ca_type": 0,
+# "ca_value": "ENGLISH"
+# }
+# ],
+# "country_code": "US"
+# }
+# },
+# "name": "eth1"
+# }
+# ]
+#
+# "commands": [
+# "delete service lldp interface eth2 location",
+# "set service lldp interface eth2 'disable'",
+# "set service lldp interface eth2 location civic-based country-code 'US'",
+# "set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH'",
+# "delete service lldp interface eth1 location",
+# "set service lldp interface eth1 'disable'",
+# "set service lldp interface eth1 location coordinate-based latitude '33.524449N'",
+# "set service lldp interface eth1 location coordinate-based altitude '2200'",
+# "set service lldp interface eth1 location coordinate-based datum 'WGS84'",
+# "set service lldp interface eth1 location coordinate-based longitude '222.267255W'"
+# ]
+#
+# "after": [
+# {
+# "location": {
+# "civic_based": {
+# "ca_info": [
+# {
+# "ca_type": 0,
+# "ca_value": "ENGLISH"
+# }
+# ],
+# "country_code": "US"
+# }
+# },
+# "name": "eth2"
+# },
+# {
+# "location": {
+# "coordinate_based": {
+# "altitude": 2200,
+# "datum": "WGS84",
+# "latitude": "33.524449N",
+# "longitude": "222.267255W"
+# }
+# },
+# "name": "eth1"
+# }
+# ]
+#
+# After state:
+# -------------
+#
+# vyos@vyos:~$ show configuration commands | grep lldp
+# set service lldp interface eth1 'disable'
+# set service lldp interface eth1 location coordinate-based altitude '2200'
+# set service lldp interface eth1 location coordinate-based datum 'WGS84'
+# set service lldp interface eth1 location coordinate-based latitude '33.524449N'
+# set service lldp interface eth1 location coordinate-based longitude '222.267255W'
+# set service lldp interface eth2 'disable'
+# set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH'
+# set service lldp interface eth2 location civic-based country-code 'US'
+
+
+# Using overridden
+#
+# Before state
+# --------------
+#
+# vyos@vyos:~$ show configuration commands | grep lldp
+# set service lldp interface eth1 'disable'
+# set service lldp interface eth1 location coordinate-based altitude '2200'
+# set service lldp interface eth1 location coordinate-based datum 'WGS84'
+# set service lldp interface eth1 location coordinate-based latitude '33.524449N'
+# set service lldp interface eth1 location coordinate-based longitude '222.267255W'
+# set service lldp interface eth2 'disable'
+# set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH'
+# set service lldp interface eth2 location civic-based country-code 'US'
+#
+- name: Override all device configuration with provided configuration
+  vyos_lldp_interfaces:
+ config:
+ - name: 'eth2'
+ location:
+ elin: 0000000911
+
+ state: overridden
+#
+#
+# -------------------------
+# Module Execution Result
+# -------------------------
+#
+# "before": [
+# {
+# "enable": false,
+# "location": {
+# "civic_based": {
+# "ca_info": [
+# {
+# "ca_type": 0,
+# "ca_value": "ENGLISH"
+# }
+# ],
+# "country_code": "US"
+# }
+# },
+# "name": "eth2"
+# },
+# {
+# "enable": false,
+# "location": {
+# "coordinate_based": {
+# "altitude": 2200,
+# "datum": "WGS84",
+# "latitude": "33.524449N",
+# "longitude": "222.267255W"
+# }
+# },
+# "name": "eth1"
+# }
+# ]
+#
+# "commands": [
+# "delete service lldp interface eth2 location",
+# "delete service lldp interface eth2 disable",
+# "set service lldp interface eth2 location elin 0000000911"
+# ]
+#
+# "after": [
+# {
+# "location": {
+# "elin": 0000000911
+# },
+# "name": "eth2"
+# }
+# ]
+#
+#
+# After state
+# ------------
+#
+# vyos@vyos# run show configuration commands | grep lldp
+# set service lldp interface eth2 location elin '0000000911'
+
+
+# Using deleted
+#
+# Before state
+# -------------
+#
+# vyos@vyos# run show configuration commands | grep lldp
+# set service lldp interface eth2 location elin '0000000911'
+#
+- name: Delete LLDP interface attributes of given interfaces
+  vyos_lldp_interfaces:
+ config:
+ - name: 'eth2'
+ state: deleted
+#
+#
+# ------------------------
+# Module Execution Results
+# ------------------------
+#
+# "before": [
+# {
+# "location": {
+# "elin": 0000000911
+# },
+# "name": "eth2"
+# }
+# ]
+#
+# "commands": [
+# "delete service lldp interface eth2"
+# ]
+#
+# "after": []
+# After state
+# ------------
+# vyos@vyos# run show configuration commands | grep lldp
+# set service 'lldp'
+
+
+"""
+RETURN = """
+before:
+ description: The configuration as structured data prior to module invocation.
+ returned: always
+ type: list
+ sample: >
+ The configuration returned will always be in the same format
+    as the parameters above.
+after:
+ description: The configuration as structured data after module completion.
+ returned: when changed
+ type: list
+ sample: >
+ The configuration returned will always be in the same format
+    as the parameters above.
+commands:
+ description: The set of commands pushed to the remote device.
+ returned: always
+ type: list
+ sample:
+ - "set service lldp interface eth2 'disable'"
+ - "delete service lldp interface eth1 location"
+"""
+
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_interfaces.lldp_interfaces import (
+ Lldp_interfacesArgs,
+)
+from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.config.lldp_interfaces.lldp_interfaces import (
+ Lldp_interfaces,
+)
+
+
+def main():
+ """
+ Main entry point for module execution
+
+    :returns: the result from module invocation
+ """
+ required_if = [
+ ("state", "merged", ("config",)),
+ ("state", "replaced", ("config",)),
+ ("state", "overridden", ("config",)),
+ ]
+ module = AnsibleModule(
+ argument_spec=Lldp_interfacesArgs.argument_spec,
+ required_if=required_if,
+ supports_check_mode=True,
+ )
+
+ result = Lldp_interfaces(module).execute_module()
+ module.exit_json(**result)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py
new file mode 100644
index 00000000..fe7712f6
--- /dev/null
+++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py
@@ -0,0 +1,53 @@
+#
+# (c) 2016 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import absolute_import, division, print_function
+
+__metaclass__ = type
+
+import os
+import re
+
+from ansible.plugins.terminal import TerminalBase
+from ansible.errors import AnsibleConnectionFailure
+
+
+class TerminalModule(TerminalBase):
+
+ terminal_stdout_re = [
+ re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
+ re.compile(br"\@[\w\-\.]+:\S+?[>#\$] ?$"),
+ ]
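+    # Illustrative prompts (examples added for clarity, not from the original
+    # source): the first pattern matches "#"/">"-terminated prompts such as
+    # "vyos@vyos# ", while the second matches shell-style prompts like
+    # "vyos@vyos:~$ ".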
+
+ terminal_stderr_re = [
+ re.compile(br"\n\s*Invalid command:"),
+ re.compile(br"\nCommit failed"),
+ re.compile(br"\n\s+Set failed"),
+ ]
+
+    # int() cast: getenv returns a str when the variable is set, which would
+    # break the b"%d" formatting in on_open_shell below
+    terminal_length = int(os.getenv("ANSIBLE_VYOS_TERMINAL_LENGTH", 10000))
+
+ def on_open_shell(self):
+ try:
+ for cmd in (b"set terminal length 0", b"set terminal width 512"):
+ self._exec_cli_command(cmd)
+ self._exec_cli_command(
+ b"set terminal length %d" % self.terminal_length
+ )
+ except AnsibleConnectionFailure:
+ raise AnsibleConnectionFailure("unable to set terminal parameters")
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py
new file mode 120000
index 00000000..0364d766
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py
@@ -0,0 +1 @@
+../../../../../../plugins/action/win_copy.py \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1
new file mode 120000
index 00000000..6fc438d6
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1
@@ -0,0 +1 @@
+../../../../../../plugins/modules/async_status.ps1 \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1
new file mode 120000
index 00000000..81d8afa3
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_acl.ps1 \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py
new file mode 120000
index 00000000..3a2434cf
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_acl.py \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1
new file mode 120000
index 00000000..a34fb012
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_copy.ps1 \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py
new file mode 120000
index 00000000..2d2c69a2
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_copy.py \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1
new file mode 120000
index 00000000..8ee5c2b5
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_file.ps1 \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py
new file mode 120000
index 00000000..b4bc0583
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_file.py \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1
new file mode 120000
index 00000000..d7b25ed0
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_ping.ps1 \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py
new file mode 120000
index 00000000..0b97c87b
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_ping.py \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1
new file mode 120000
index 00000000..eb07a017
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_shell.ps1 \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py
new file mode 120000
index 00000000..3c6f0749
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_shell.py \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1
new file mode 120000
index 00000000..62a7a40a
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_stat.ps1 \ No newline at end of file
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py
new file mode 120000
index 00000000..1db4c95e
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py
@@ -0,0 +1 @@
+../../../../../../plugins/modules/win_stat.py \ No newline at end of file
diff --git a/test/support/windows-integration/plugins/action/win_copy.py b/test/support/windows-integration/plugins/action/win_copy.py
new file mode 100644
index 00000000..adb918be
--- /dev/null
+++ b/test/support/windows-integration/plugins/action/win_copy.py
@@ -0,0 +1,522 @@
+# This file is part of Ansible
+
+# Copyright (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import base64
+import json
+import os
+import os.path
+import shutil
+import tempfile
+import traceback
+import zipfile
+
+from ansible import constants as C
+from ansible.errors import AnsibleError, AnsibleFileNotFound
+from ansible.module_utils._text import to_bytes, to_native, to_text
+from ansible.module_utils.parsing.convert_bool import boolean
+from ansible.plugins.action import ActionBase
+from ansible.utils.hashing import checksum
+
+
+def _walk_dirs(topdir, loader, decrypt=True, base_path=None, local_follow=False, trailing_slash_detector=None, checksum_check=False):
+ """
+ Walk a filesystem tree returning enough information to copy the files.
+ This is similar to the _walk_dirs function in ``copy.py`` but returns
+ a dict instead of a tuple for each entry and includes the checksum of
+ a local file if wanted.
+
+ :arg topdir: The directory that the filesystem tree is rooted at
+ :arg loader: The self._loader object from ActionBase
+ :kwarg decrypt: Whether to decrypt a file encrypted with ansible-vault
+ :kwarg base_path: The initial directory structure to strip off of the
+ files for the destination directory. If this is None (the default),
+ the base_path is set to ``top_dir``.
+ :kwarg local_follow: Whether to follow symlinks on the source. When set
+ to False, no symlinks are dereferenced. When set to True (the
+ default), the code will dereference most symlinks. However, symlinks
+ can still be present if needed to break a circular link.
+ :kwarg trailing_slash_detector: Function to determine if a path has
+ a trailing directory separator. Only needed when dealing with paths on
+ a remote machine (in which case, pass in a function that is aware of the
+ directory separator conventions on the remote machine).
+    :kwarg checksum_check: Whether to get the checksum of the local file and
+        add it to the dict
+    :returns: dictionary of dictionaries. All of the path elements in the structure are text strings.
+        This separates all the files, directories, and symlinks along with
+        important information about each::
+
+ {
+            'files': [{
+ src: '/absolute/path/to/copy/from',
+ dest: 'relative/path/to/copy/to',
+ checksum: 'b54ba7f5621240d403f06815f7246006ef8c7d43'
+ }, ...],
+            'directories': [{
+ src: '/absolute/path/to/copy/from',
+ dest: 'relative/path/to/copy/to'
+ }, ...],
+            'symlinks': [{
+ src: '/symlink/target/path',
+ dest: 'relative/path/to/copy/to'
+ }, ...],
+
+ }
+
+ The ``symlinks`` field is only populated if ``local_follow`` is set to False
+ *or* a circular symlink cannot be dereferenced. The ``checksum`` entry is set
+ to None if checksum_check=False.
+
+ """
+ # Convert the path segments into byte strings
+
+ r_files = {'files': [], 'directories': [], 'symlinks': []}
+
+ def _recurse(topdir, rel_offset, parent_dirs, rel_base=u'', checksum_check=False):
+ """
+        This is a closure (a function utilizing variables from its parent
+ function's scope) so that we only need one copy of all the containers.
+ Note that this function uses side effects (See the Variables used from
+ outer scope).
+
+ :arg topdir: The directory we are walking for files
+ :arg rel_offset: Integer defining how many characters to strip off of
+ the beginning of a path
+ :arg parent_dirs: Directories that we're copying that this directory is in.
+ :kwarg rel_base: String to prepend to the path after ``rel_offset`` is
+ applied to form the relative path.
+
+ Variables used from the outer scope
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ :r_files: Dictionary of files in the hierarchy. See the return value
+ for :func:`walk` for the structure of this dictionary.
+ :local_follow: Read-only inside of :func:`_recurse`. Whether to follow symlinks
+ """
+ for base_path, sub_folders, files in os.walk(topdir):
+ for filename in files:
+ filepath = os.path.join(base_path, filename)
+ dest_filepath = os.path.join(rel_base, filepath[rel_offset:])
+
+ if os.path.islink(filepath):
+                    # Dereference the symlink
+ real_file = loader.get_real_file(os.path.realpath(filepath), decrypt=decrypt)
+ if local_follow and os.path.isfile(real_file):
+ # Add the file pointed to by the symlink
+ r_files['files'].append(
+ {
+ "src": real_file,
+ "dest": dest_filepath,
+ "checksum": _get_local_checksum(checksum_check, real_file)
+ }
+ )
+ else:
+ # Mark this file as a symlink to copy
+ r_files['symlinks'].append({"src": os.readlink(filepath), "dest": dest_filepath})
+ else:
+ # Just a normal file
+ real_file = loader.get_real_file(filepath, decrypt=decrypt)
+ r_files['files'].append(
+ {
+ "src": real_file,
+ "dest": dest_filepath,
+ "checksum": _get_local_checksum(checksum_check, real_file)
+ }
+ )
+
+ for dirname in sub_folders:
+ dirpath = os.path.join(base_path, dirname)
+ dest_dirpath = os.path.join(rel_base, dirpath[rel_offset:])
+ real_dir = os.path.realpath(dirpath)
+ dir_stats = os.stat(real_dir)
+
+ if os.path.islink(dirpath):
+ if local_follow:
+ if (dir_stats.st_dev, dir_stats.st_ino) in parent_dirs:
+ # Just insert the symlink if the target directory
+ # exists inside of the copy already
+ r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath})
+ else:
+ # Walk the dirpath to find all parent directories.
+ new_parents = set()
+ parent_dir_list = os.path.dirname(dirpath).split(os.path.sep)
+ for parent in range(len(parent_dir_list), 0, -1):
+ parent_stat = os.stat(u'/'.join(parent_dir_list[:parent]))
+ if (parent_stat.st_dev, parent_stat.st_ino) in parent_dirs:
+ # Reached the point at which the directory
+ # tree is already known. Don't add any
+ # more or we might go to an ancestor that
+ # isn't being copied.
+ break
+ new_parents.add((parent_stat.st_dev, parent_stat.st_ino))
+
+ if (dir_stats.st_dev, dir_stats.st_ino) in new_parents:
+                            # This was a circular symlink, so add it as
+                            # a symlink
+ r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath})
+ else:
+ # Walk the directory pointed to by the symlink
+ r_files['directories'].append({"src": real_dir, "dest": dest_dirpath})
+ offset = len(real_dir) + 1
+ _recurse(real_dir, offset, parent_dirs.union(new_parents),
+ rel_base=dest_dirpath,
+ checksum_check=checksum_check)
+ else:
+ # Add the symlink to the destination
+ r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath})
+ else:
+ # Just a normal directory
+ r_files['directories'].append({"src": dirpath, "dest": dest_dirpath})
+
+ # Check if the source ends with a "/" so that we know which directory
+ # level to work at (similar to rsync)
+ source_trailing_slash = False
+ if trailing_slash_detector:
+ source_trailing_slash = trailing_slash_detector(topdir)
+ else:
+ source_trailing_slash = topdir.endswith(os.path.sep)
+
+ # Calculate the offset needed to strip the base_path to make relative
+ # paths
+ if base_path is None:
+ base_path = topdir
+ if not source_trailing_slash:
+ base_path = os.path.dirname(base_path)
+ if topdir.startswith(base_path):
+ offset = len(base_path)
+
+ # Make sure we're making the new paths relative
+ if trailing_slash_detector and not trailing_slash_detector(base_path):
+ offset += 1
+ elif not base_path.endswith(os.path.sep):
+ offset += 1
+
+ if os.path.islink(topdir) and not local_follow:
+ r_files['symlinks'] = {"src": os.readlink(topdir), "dest": os.path.basename(topdir)}
+ return r_files
+
+ dir_stats = os.stat(topdir)
+ parents = frozenset(((dir_stats.st_dev, dir_stats.st_ino),))
+ # Actually walk the directory hierarchy
+ _recurse(topdir, offset, parents, checksum_check=checksum_check)
+
+ return r_files
+
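+# A minimal usage sketch with hypothetical names ('loader' stands for the
+# DataLoader-like object an action plugin sees as self._loader):
+#
+#   tree = _walk_dirs('/tmp/src/', loader, checksum_check=False)
+#   # tree['files'] -> [{'src': ..., 'dest': ..., 'checksum': None}, ...]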
+
+def _get_local_checksum(get_checksum, local_path):
+ if get_checksum:
+ return checksum(local_path)
+ else:
+ return None
+
+
+class ActionModule(ActionBase):
+
+ WIN_PATH_SEPARATOR = "\\"
+
+ def _create_content_tempfile(self, content):
+ ''' Create a tempfile containing defined content '''
+ fd, content_tempfile = tempfile.mkstemp(dir=C.DEFAULT_LOCAL_TMP)
+ f = os.fdopen(fd, 'wb')
+ content = to_bytes(content)
+ try:
+ f.write(content)
+ except Exception as err:
+ os.remove(content_tempfile)
+ raise Exception(err)
+ finally:
+ f.close()
+ return content_tempfile
+
+ def _create_zip_tempfile(self, files, directories):
+ tmpdir = tempfile.mkdtemp(dir=C.DEFAULT_LOCAL_TMP)
+ zip_file_path = os.path.join(tmpdir, "win_copy.zip")
+ zip_file = zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_STORED, True)
+
+ # encoding the file/dir name with base64 so Windows can unzip a unicode
+ # filename and get the right name, Windows doesn't handle unicode names
+ # very well
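+        # (illustrative values: base64.b64encode(b"dir\\file.txt") yields
+        # b"ZGlyXGZpbGUudHh0")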
+ for directory in directories:
+ directory_path = to_bytes(directory['src'], errors='surrogate_or_strict')
+ archive_path = to_bytes(directory['dest'], errors='surrogate_or_strict')
+
+ encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict')
+ zip_file.write(directory_path, encoded_path, zipfile.ZIP_DEFLATED)
+
+ for file in files:
+ file_path = to_bytes(file['src'], errors='surrogate_or_strict')
+ archive_path = to_bytes(file['dest'], errors='surrogate_or_strict')
+
+ encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict')
+ zip_file.write(file_path, encoded_path, zipfile.ZIP_DEFLATED)
+
+ return zip_file_path
+
+ def _remove_tempfile_if_content_defined(self, content, content_tempfile):
+ if content is not None:
+ os.remove(content_tempfile)
+
+ def _copy_single_file(self, local_file, dest, source_rel, task_vars, tmp, backup):
+ if self._play_context.check_mode:
+ module_return = dict(changed=True)
+ return module_return
+
+ # copy the file across to the server
+ tmp_src = self._connection._shell.join_path(tmp, 'source')
+ self._transfer_file(local_file, tmp_src)
+
+ copy_args = self._task.args.copy()
+ copy_args.update(
+ dict(
+ dest=dest,
+ src=tmp_src,
+ _original_basename=source_rel,
+ _copy_mode="single",
+ backup=backup,
+ )
+ )
+ copy_args.pop('content', None)
+
+ copy_result = self._execute_module(module_name="copy",
+ module_args=copy_args,
+ task_vars=task_vars)
+
+ return copy_result
+
+ def _copy_zip_file(self, dest, files, directories, task_vars, tmp, backup):
+ # create local zip file containing all the files and directories that
+ # need to be copied to the server
+ if self._play_context.check_mode:
+ module_return = dict(changed=True)
+ return module_return
+
+ try:
+ zip_file = self._create_zip_tempfile(files, directories)
+ except Exception as e:
+ module_return = dict(
+ changed=False,
+ failed=True,
+ msg="failed to create tmp zip file: %s" % to_text(e),
+ exception=traceback.format_exc()
+ )
+ return module_return
+
+ zip_path = self._loader.get_real_file(zip_file)
+
+ # send zip file to remote, file must end in .zip so
+ # Com Shell.Application works
+ tmp_src = self._connection._shell.join_path(tmp, 'source.zip')
+ self._transfer_file(zip_path, tmp_src)
+
+ # run the explode operation of win_copy on remote
+ copy_args = self._task.args.copy()
+ copy_args.update(
+ dict(
+ src=tmp_src,
+ dest=dest,
+ _copy_mode="explode",
+ backup=backup,
+ )
+ )
+ copy_args.pop('content', None)
+ module_return = self._execute_module(module_name='copy',
+ module_args=copy_args,
+ task_vars=task_vars)
+ shutil.rmtree(os.path.dirname(zip_path))
+ return module_return
+
+ def run(self, tmp=None, task_vars=None):
+ ''' handler for file transfer operations '''
+ if task_vars is None:
+ task_vars = dict()
+
+ result = super(ActionModule, self).run(tmp, task_vars)
+ del tmp # tmp no longer has any effect
+
+ source = self._task.args.get('src', None)
+ content = self._task.args.get('content', None)
+ dest = self._task.args.get('dest', None)
+ remote_src = boolean(self._task.args.get('remote_src', False), strict=False)
+ local_follow = boolean(self._task.args.get('local_follow', False), strict=False)
+ force = boolean(self._task.args.get('force', True), strict=False)
+ decrypt = boolean(self._task.args.get('decrypt', True), strict=False)
+ backup = boolean(self._task.args.get('backup', False), strict=False)
+
+ result['src'] = source
+ result['dest'] = dest
+
+ result['failed'] = True
+ if (source is None and content is None) or dest is None:
+ result['msg'] = "src (or content) and dest are required"
+ elif source is not None and content is not None:
+ result['msg'] = "src and content are mutually exclusive"
+ elif content is not None and dest is not None and (
+ dest.endswith(os.path.sep) or dest.endswith(self.WIN_PATH_SEPARATOR)):
+ result['msg'] = "dest must be a file if content is defined"
+ else:
+ del result['failed']
+
+ if result.get('failed'):
+ return result
+
+ # If content is defined make a temp file and write the content into it
+ content_tempfile = None
+ if content is not None:
+ try:
+ # if content comes to us as a dict it should be decoded json.
+ # We need to encode it back into a string and write it out
+ if isinstance(content, dict) or isinstance(content, list):
+ content_tempfile = self._create_content_tempfile(json.dumps(content))
+ else:
+ content_tempfile = self._create_content_tempfile(content)
+ source = content_tempfile
+ except Exception as err:
+ result['failed'] = True
+ result['msg'] = "could not write content tmp file: %s" % to_native(err)
+ return result
+ # all actions should occur on the remote server, run win_copy module
+ elif remote_src:
+ new_module_args = self._task.args.copy()
+ new_module_args.update(
+ dict(
+ _copy_mode="remote",
+ dest=dest,
+ src=source,
+ force=force,
+ backup=backup,
+ )
+ )
+ new_module_args.pop('content', None)
+ result.update(self._execute_module(module_args=new_module_args, task_vars=task_vars))
+ return result
+ # find_needle returns a path that may not have a trailing slash on a
+ # directory so we need to find that out first and append at the end
+ else:
+ trailing_slash = source.endswith(os.path.sep)
+ try:
+ # find in expected paths
+ source = self._find_needle('files', source)
+ except AnsibleError as e:
+ result['failed'] = True
+ result['msg'] = to_text(e)
+ result['exception'] = traceback.format_exc()
+ return result
+
+ if trailing_slash != source.endswith(os.path.sep):
+ if source[-1] == os.path.sep:
+ source = source[:-1]
+ else:
+ source = source + os.path.sep
+
+        # A list of source file tuples (full_path, relative_path) which we will try to copy to the destination
+ source_files = {'files': [], 'directories': [], 'symlinks': []}
+
+ # If source is a directory populate our list else source is a file and translate it to a tuple.
+ if os.path.isdir(to_bytes(source, errors='surrogate_or_strict')):
+ result['operation'] = 'folder_copy'
+
+ # Get a list of the files we want to replicate on the remote side
+ source_files = _walk_dirs(source, self._loader, decrypt=decrypt, local_follow=local_follow,
+ trailing_slash_detector=self._connection._shell.path_has_trailing_slash,
+ checksum_check=force)
+
+ # If it's recursive copy, destination is always a dir,
+ # explicitly mark it so (note - win_copy module relies on this).
+ if not self._connection._shell.path_has_trailing_slash(dest):
+ dest = "%s%s" % (dest, self.WIN_PATH_SEPARATOR)
+
+ check_dest = dest
+ # Source is a file, add details to source_files dict
+ else:
+ result['operation'] = 'file_copy'
+
+ # If the local file does not exist, get_real_file() raises AnsibleFileNotFound
+ try:
+ source_full = self._loader.get_real_file(source, decrypt=decrypt)
+ except AnsibleFileNotFound as e:
+ result['failed'] = True
+                # use 'source' here; source_full is unbound when get_real_file raises
+                result['msg'] = "could not find src=%s, %s" % (source, to_text(e))
+ return result
+
+ original_basename = os.path.basename(source)
+ result['original_basename'] = original_basename
+
+ # check if dest ends with / or \ and append source filename to dest
+ if self._connection._shell.path_has_trailing_slash(dest):
+ check_dest = dest
+ filename = original_basename
+ result['dest'] = self._connection._shell.join_path(dest, filename)
+ else:
+ # replace \\ with / so we can use os.path to get the filename or dirname
+ unix_path = dest.replace(self.WIN_PATH_SEPARATOR, os.path.sep)
+ filename = os.path.basename(unix_path)
+ check_dest = os.path.dirname(unix_path)
+
+ file_checksum = _get_local_checksum(force, source_full)
+ source_files['files'].append(
+ dict(
+ src=source_full,
+ dest=filename,
+ checksum=file_checksum
+ )
+ )
+ result['checksum'] = file_checksum
+ result['size'] = os.path.getsize(to_bytes(source_full, errors='surrogate_or_strict'))
+
+ # find out the files/directories/symlinks that we need to copy to the server
+ query_args = self._task.args.copy()
+ query_args.update(
+ dict(
+ _copy_mode="query",
+ dest=check_dest,
+ force=force,
+ files=source_files['files'],
+ directories=source_files['directories'],
+ symlinks=source_files['symlinks'],
+ )
+ )
+        # src is not required for query; path validation would fail if src has unix-allowed chars
+ query_args.pop('src', None)
+
+ query_args.pop('content', None)
+ query_return = self._execute_module(module_args=query_args,
+ task_vars=task_vars)
+
+ if query_return.get('failed') is True:
+ result.update(query_return)
+ return result
+
+        if (len(query_return['files']) > 0 or len(query_return['directories']) > 0) and self._connection._shell.tmpdir is None:
+ self._connection._shell.tmpdir = self._make_tmp_path()
+
+ if len(query_return['files']) == 1 and len(query_return['directories']) == 0:
+ # we only need to copy 1 file, don't mess around with zips
+ file_src = query_return['files'][0]['src']
+ file_dest = query_return['files'][0]['dest']
+ result.update(self._copy_single_file(file_src, dest, file_dest,
+ task_vars, self._connection._shell.tmpdir, backup))
+ if result.get('failed') is True:
+ result['msg'] = "failed to copy file %s: %s" % (file_src, result['msg'])
+ result['changed'] = True
+
+ elif len(query_return['files']) > 0 or len(query_return['directories']) > 0:
+ # either multiple files or directories need to be copied, compress
+ # to a zip and 'explode' the zip on the server
+ # TODO: handle symlinks
+ result.update(self._copy_zip_file(dest, source_files['files'],
+ source_files['directories'],
+ task_vars, self._connection._shell.tmpdir, backup))
+ result['changed'] = True
+ else:
+ # no operations need to occur
+ result['failed'] = False
+ result['changed'] = False
+
+ # remove the content tmp file and remote tmp file if it was created
+ self._remove_tempfile_if_content_defined(content, content_tempfile)
+ self._remove_tmp_path(self._connection._shell.tmpdir)
+ return result
diff --git a/test/support/windows-integration/plugins/action/win_reboot.py b/test/support/windows-integration/plugins/action/win_reboot.py
new file mode 100644
index 00000000..c408f4f3
--- /dev/null
+++ b/test/support/windows-integration/plugins/action/win_reboot.py
@@ -0,0 +1,96 @@
+# Copyright: (c) 2018, Matt Davis <mdavis@ansible.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from datetime import datetime
+
+from ansible.errors import AnsibleError
+from ansible.module_utils._text import to_native
+from ansible.plugins.action import ActionBase
+from ansible.plugins.action.reboot import ActionModule as RebootActionModule
+from ansible.utils.display import Display
+
+display = Display()
+
+
+class TimedOutException(Exception):
+ pass
+
+
+class ActionModule(RebootActionModule, ActionBase):
+ TRANSFERS_FILES = False
+ _VALID_ARGS = frozenset((
+ 'connect_timeout', 'connect_timeout_sec', 'msg', 'post_reboot_delay', 'post_reboot_delay_sec', 'pre_reboot_delay', 'pre_reboot_delay_sec',
+ 'reboot_timeout', 'reboot_timeout_sec', 'shutdown_timeout', 'shutdown_timeout_sec', 'test_command',
+ ))
+
+ DEFAULT_BOOT_TIME_COMMAND = "(Get-WmiObject -ClassName Win32_OperatingSystem).LastBootUpTime"
+ DEFAULT_CONNECT_TIMEOUT = 5
+ DEFAULT_PRE_REBOOT_DELAY = 2
+ DEFAULT_SUDOABLE = False
+ DEFAULT_SHUTDOWN_COMMAND_ARGS = '/r /t {delay_sec} /c "{message}"'
+
+ DEPRECATED_ARGS = {
+ 'shutdown_timeout': '2.5',
+ 'shutdown_timeout_sec': '2.5',
+ }
+
+ def __init__(self, *args, **kwargs):
+ super(ActionModule, self).__init__(*args, **kwargs)
+
+ def get_distribution(self, task_vars):
+ return {'name': 'windows', 'version': '', 'family': ''}
+
+ def get_shutdown_command(self, task_vars, distribution):
+ return self.DEFAULT_SHUTDOWN_COMMAND
+
+ def run_test_command(self, distribution, **kwargs):
+ # Need to wrap the test_command in our PowerShell encoded wrapper. This is done to align the command input to a
+ # common shell and to allow the psrp connection plugin to report the correct exit code without manually setting
+ # $LASTEXITCODE for just that plugin.
+ test_command = self._task.args.get('test_command', self.DEFAULT_TEST_COMMAND)
+ kwargs['test_command'] = self._connection._shell._encode_script(test_command)
+ super(ActionModule, self).run_test_command(distribution, **kwargs)
+
+ def perform_reboot(self, task_vars, distribution):
+ shutdown_command = self.get_shutdown_command(task_vars, distribution)
+ shutdown_command_args = self.get_shutdown_command_args(distribution)
+ reboot_command = self._connection._shell._encode_script('{0} {1}'.format(shutdown_command, shutdown_command_args))
+
+ display.vvv("{action}: rebooting server...".format(action=self._task.action))
+ display.debug("{action}: distribution: {dist}".format(action=self._task.action, dist=distribution))
+ display.debug("{action}: rebooting server with command '{command}'".format(action=self._task.action, command=reboot_command))
+
+ result = {}
+ reboot_result = self._low_level_execute_command(reboot_command, sudoable=self.DEFAULT_SUDOABLE)
+ result['start'] = datetime.utcnow()
+
+ # Test for "A system shutdown has already been scheduled. (1190)" and handle it gracefully
+ stdout = reboot_result['stdout']
+ stderr = reboot_result['stderr']
+ if reboot_result['rc'] == 1190 or (reboot_result['rc'] != 0 and "(1190)" in reboot_result['stderr']):
+ display.warning('A scheduled reboot was pre-empted by Ansible.')
+
+ # Try to abort (this may fail if it was already aborted)
+ result1 = self._low_level_execute_command(self._connection._shell._encode_script('shutdown /a'),
+ sudoable=self.DEFAULT_SUDOABLE)
+
+ # Initiate reboot again
+ result2 = self._low_level_execute_command(reboot_command, sudoable=self.DEFAULT_SUDOABLE)
+
+ reboot_result['rc'] = result2['rc']
+ stdout += result1['stdout'] + result2['stdout']
+ stderr += result1['stderr'] + result2['stderr']
+
+ if reboot_result['rc'] != 0:
+ result['failed'] = True
+ result['rebooted'] = False
+ result['msg'] = "Reboot command failed, error was: {stdout} {stderr}".format(
+ stdout=to_native(stdout.strip()),
+ stderr=to_native(stderr.strip()))
+ return result
+
+ result['failed'] = False
+ return result
diff --git a/test/support/windows-integration/plugins/action/win_template.py b/test/support/windows-integration/plugins/action/win_template.py
new file mode 100644
index 00000000..20494b93
--- /dev/null
+++ b/test/support/windows-integration/plugins/action/win_template.py
@@ -0,0 +1,29 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from ansible.plugins.action import ActionBase
+from ansible.plugins.action.template import ActionModule as TemplateActionModule
+
+
+# Even though TemplateActionModule inherits from ActionBase, we still need to
+# directly inherit from ActionBase to appease the plugin loader.
+class ActionModule(TemplateActionModule, ActionBase):
+ DEFAULT_NEWLINE_SEQUENCE = '\r\n'
diff --git a/test/support/windows-integration/plugins/become/runas.py b/test/support/windows-integration/plugins/become/runas.py
new file mode 100644
index 00000000..c8ae881c
--- /dev/null
+++ b/test/support/windows-integration/plugins/become/runas.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright: (c) 2018, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+DOCUMENTATION = """
+ become: runas
+ short_description: Run As user
+ description:
+ - This become plugin allows your remote/login user to execute commands as another user via the Windows runas facility.
+ author: ansible (@core)
+ version_added: "2.8"
+ options:
+ become_user:
+ description: User you 'become' to execute the task
+ ini:
+ - section: privilege_escalation
+ key: become_user
+ - section: runas_become_plugin
+ key: user
+ vars:
+ - name: ansible_become_user
+ - name: ansible_runas_user
+ env:
+ - name: ANSIBLE_BECOME_USER
+ - name: ANSIBLE_RUNAS_USER
+ required: True
+ become_flags:
+ description: Options to pass to runas, a space-delimited list of k=v pairs
+ default: ''
+ ini:
+ - section: privilege_escalation
+ key: become_flags
+ - section: runas_become_plugin
+ key: flags
+ vars:
+ - name: ansible_become_flags
+ - name: ansible_runas_flags
+ env:
+ - name: ANSIBLE_BECOME_FLAGS
+ - name: ANSIBLE_RUNAS_FLAGS
+ become_pass:
+ description: Password for the become user
+ ini:
+ - section: runas_become_plugin
+ key: password
+ vars:
+ - name: ansible_become_password
+ - name: ansible_become_pass
+ - name: ansible_runas_pass
+ env:
+ - name: ANSIBLE_BECOME_PASS
+ - name: ANSIBLE_RUNAS_PASS
+ notes:
+ - runas is actually implemented in the PowerShell module handler and as such can only be used with winrm connections.
+ - This plugin ignores the 'become_exe' setting as it uses an API and not an executable.
+ - The Secondary Logon service (seclogon) must be running to use runas.
+"""
+
+from ansible.plugins.become import BecomeBase
+
+
+class BecomeModule(BecomeBase):
+
+ name = 'runas'
+
+ def build_become_command(self, cmd, shell):
+ # runas is implemented inside the winrm connection plugin
+ return cmd
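+
+# Illustrative only: since build_become_command returns the command line
+# unchanged, enabling runas is driven entirely by become variables, e.g. in
+# inventory or play vars (the values below are assumptions for example):
+#
+# ansible_become: yes
+# ansible_become_method: runas
+# ansible_become_user: Administrator
+# ansible_become_pass: '{{ vault_admin_password }}'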
diff --git a/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs b/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs
new file mode 100644
index 00000000..be0f3db3
--- /dev/null
+++ b/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs
@@ -0,0 +1,1341 @@
+using Microsoft.Win32.SafeHandles;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.ConstrainedExecution;
+using System.Runtime.InteropServices;
+using System.Security.Principal;
+using System.Text;
+using Ansible.Privilege;
+
+namespace Ansible.Service
+{
+ internal class NativeHelpers
+ {
+ [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+ public struct ENUM_SERVICE_STATUSW
+ {
+ public string lpServiceName;
+ public string lpDisplayName;
+ public SERVICE_STATUS ServiceStatus;
+ }
+
+ [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+ public struct QUERY_SERVICE_CONFIGW
+ {
+ public ServiceType dwServiceType;
+ public ServiceStartType dwStartType;
+ public ErrorControl dwErrorControl;
+ [MarshalAs(UnmanagedType.LPWStr)] public string lpBinaryPathName;
+ [MarshalAs(UnmanagedType.LPWStr)] public string lpLoadOrderGroup;
+ public Int32 dwTagId;
+ public IntPtr lpDependencies; // Can't rely on marshaling as dependencies are delimited by \0.
+ [MarshalAs(UnmanagedType.LPWStr)] public string lpServiceStartName;
+ [MarshalAs(UnmanagedType.LPWStr)] public string lpDisplayName;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SC_ACTION
+ {
+ public FailureAction Type;
+ public UInt32 Delay;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_DELAYED_AUTO_START_INFO
+ {
+ public bool fDelayedAutostart;
+ }
+
+ [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+ public struct SERVICE_DESCRIPTIONW
+ {
+ [MarshalAs(UnmanagedType.LPWStr)] public string lpDescription;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_FAILURE_ACTIONS_FLAG
+ {
+ public bool fFailureActionsOnNonCrashFailures;
+ }
+
+ [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+ public struct SERVICE_FAILURE_ACTIONSW
+ {
+ public UInt32 dwResetPeriod;
+ [MarshalAs(UnmanagedType.LPWStr)] public string lpRebootMsg;
+ [MarshalAs(UnmanagedType.LPWStr)] public string lpCommand;
+ public UInt32 cActions;
+ public IntPtr lpsaActions;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_LAUNCH_PROTECTED_INFO
+ {
+ public LaunchProtection dwLaunchProtected;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_PREFERRED_NODE_INFO
+ {
+ public UInt16 usPreferredNode;
+ public bool fDelete;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_PRESHUTDOWN_INFO
+ {
+ public UInt32 dwPreshutdownTimeout;
+ }
+
+ [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+ public struct SERVICE_REQUIRED_PRIVILEGES_INFOW
+ {
+ // Can't rely on marshaling as privileges are delimited by \0.
+ public IntPtr pmszRequiredPrivileges;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_SID_INFO
+ {
+ public ServiceSidInfo dwServiceSidType;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_STATUS
+ {
+ public ServiceType dwServiceType;
+ public ServiceStatus dwCurrentState;
+ public ControlsAccepted dwControlsAccepted;
+ public UInt32 dwWin32ExitCode;
+ public UInt32 dwServiceSpecificExitCode;
+ public UInt32 dwCheckPoint;
+ public UInt32 dwWaitHint;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_STATUS_PROCESS
+ {
+ public ServiceType dwServiceType;
+ public ServiceStatus dwCurrentState;
+ public ControlsAccepted dwControlsAccepted;
+ public UInt32 dwWin32ExitCode;
+ public UInt32 dwServiceSpecificExitCode;
+ public UInt32 dwCheckPoint;
+ public UInt32 dwWaitHint;
+ public UInt32 dwProcessId;
+ public ServiceFlags dwServiceFlags;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_TRIGGER
+ {
+ public TriggerType dwTriggerType;
+ public TriggerAction dwAction;
+ public IntPtr pTriggerSubtype;
+ public UInt32 cDataItems;
+ public IntPtr pDataItems;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_TRIGGER_SPECIFIC_DATA_ITEM
+ {
+ public TriggerDataType dwDataType;
+ public UInt32 cbData;
+ public IntPtr pData;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SERVICE_TRIGGER_INFO
+ {
+ public UInt32 cTriggers;
+ public IntPtr pTriggers;
+ public IntPtr pReserved;
+ }
+
+ public enum ConfigInfoLevel : uint
+ {
+ SERVICE_CONFIG_DESCRIPTION = 0x00000001,
+ SERVICE_CONFIG_FAILURE_ACTIONS = 0x00000002,
+ SERVICE_CONFIG_DELAYED_AUTO_START_INFO = 0x00000003,
+ SERVICE_CONFIG_FAILURE_ACTIONS_FLAG = 0x00000004,
+ SERVICE_CONFIG_SERVICE_SID_INFO = 0x00000005,
+ SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO = 0x00000006,
+ SERVICE_CONFIG_PRESHUTDOWN_INFO = 0x00000007,
+ SERVICE_CONFIG_TRIGGER_INFO = 0x00000008,
+ SERVICE_CONFIG_PREFERRED_NODE = 0x00000009,
+ SERVICE_CONFIG_LAUNCH_PROTECTED = 0x0000000c,
+ }
+ }
+
+ internal class NativeMethods
+ {
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern bool ChangeServiceConfigW(
+ SafeHandle hService,
+ ServiceType dwServiceType,
+ ServiceStartType dwStartType,
+ ErrorControl dwErrorControl,
+ string lpBinaryPathName,
+ string lpLoadOrderGroup,
+ IntPtr lpdwTagId,
+ string lpDependencies,
+ string lpServiceStartName,
+ string lpPassword,
+ string lpDisplayName);
+
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern bool ChangeServiceConfig2W(
+ SafeHandle hService,
+ NativeHelpers.ConfigInfoLevel dwInfoLevel,
+ IntPtr lpInfo);
+
+ [DllImport("Advapi32.dll", SetLastError = true)]
+ public static extern bool CloseServiceHandle(
+ IntPtr hSCObject);
+
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern SafeServiceHandle CreateServiceW(
+ SafeHandle hSCManager,
+ string lpServiceName,
+ string lpDisplayName,
+ ServiceRights dwDesiredAccess,
+ ServiceType dwServiceType,
+ ServiceStartType dwStartType,
+ ErrorControl dwErrorControl,
+ string lpBinaryPathName,
+ string lpLoadOrderGroup,
+ IntPtr lpdwTagId,
+ string lpDependencies,
+ string lpServiceStartName,
+ string lpPassword);
+
+ [DllImport("Advapi32.dll", SetLastError = true)]
+ public static extern bool DeleteService(
+ SafeHandle hService);
+
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern bool EnumDependentServicesW(
+ SafeHandle hService,
+ UInt32 dwServiceState,
+ SafeMemoryBuffer lpServices,
+ UInt32 cbBufSize,
+ out UInt32 pcbBytesNeeded,
+ out UInt32 lpServicesReturned);
+
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern SafeServiceHandle OpenSCManagerW(
+ string lpMachineName,
+ string lpDatabaseName,
+ SCMRights dwDesiredAccess);
+
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern SafeServiceHandle OpenServiceW(
+ SafeHandle hSCManager,
+ string lpServiceName,
+ ServiceRights dwDesiredAccess);
+
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern bool QueryServiceConfigW(
+ SafeHandle hService,
+ IntPtr lpServiceConfig,
+ UInt32 cbBufSize,
+ out UInt32 pcbBytesNeeded);
+
+ [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+ public static extern bool QueryServiceConfig2W(
+ SafeHandle hService,
+ NativeHelpers.ConfigInfoLevel dwInfoLevel,
+ IntPtr lpBuffer,
+ UInt32 cbBufSize,
+ out UInt32 pcbBytesNeeded);
+
+ [DllImport("Advapi32.dll", SetLastError = true)]
+ public static extern bool QueryServiceStatusEx(
+ SafeHandle hService,
+ UInt32 InfoLevel,
+ IntPtr lpBuffer,
+ UInt32 cbBufSize,
+ out UInt32 pcbBytesNeeded);
+ }
+
+ internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ public UInt32 BufferLength { get; internal set; }
+
+ public SafeMemoryBuffer() : base(true) { }
+ public SafeMemoryBuffer(int cb) : base(true)
+ {
+ BufferLength = (UInt32)cb;
+ base.SetHandle(Marshal.AllocHGlobal(cb));
+ }
+ public SafeMemoryBuffer(IntPtr handle) : base(true)
+ {
+ base.SetHandle(handle);
+ }
+
+ [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
+ protected override bool ReleaseHandle()
+ {
+ Marshal.FreeHGlobal(handle);
+ return true;
+ }
+ }
+
+ internal class SafeServiceHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ public SafeServiceHandle() : base(true) { }
+ public SafeServiceHandle(IntPtr handle) : base(true) { this.handle = handle; }
+
+ [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
+ protected override bool ReleaseHandle()
+ {
+ return NativeMethods.CloseServiceHandle(handle);
+ }
+ }
+
+ [Flags]
+ public enum ControlsAccepted : uint
+ {
+ None = 0x00000000,
+ Stop = 0x00000001,
+ PauseContinue = 0x00000002,
+ Shutdown = 0x00000004,
+ ParamChange = 0x00000008,
+ NetbindChange = 0x00000010,
+ HardwareProfileChange = 0x00000020,
+ PowerEvent = 0x00000040,
+ SessionChange = 0x00000080,
+ PreShutdown = 0x00000100,
+ }
+
+ public enum ErrorControl : uint
+ {
+ Ignore = 0x00000000,
+ Normal = 0x00000001,
+ Severe = 0x00000002,
+ Critical = 0x00000003,
+ }
+
+ public enum FailureAction : uint
+ {
+ None = 0x00000000,
+ Restart = 0x00000001,
+ Reboot = 0x00000002,
+ RunCommand = 0x00000003,
+ }
+
+ public enum LaunchProtection : uint
+ {
+ None = 0,
+ Windows = 1,
+ WindowsLight = 2,
+ AntimalwareLight = 3,
+ }
+
+ [Flags]
+ public enum SCMRights : uint
+ {
+ Connect = 0x00000001,
+ CreateService = 0x00000002,
+ EnumerateService = 0x00000004,
+ Lock = 0x00000008,
+ QueryLockStatus = 0x00000010,
+ ModifyBootConfig = 0x00000020,
+ AllAccess = 0x000F003F,
+ }
+
+ [Flags]
+ public enum ServiceFlags : uint
+ {
+ None = 0x0000000,
+ RunsInSystemProcess = 0x00000001,
+ }
+
+ [Flags]
+ public enum ServiceRights : uint
+ {
+ QueryConfig = 0x00000001,
+ ChangeConfig = 0x00000002,
+ QueryStatus = 0x00000004,
+ EnumerateDependents = 0x00000008,
+ Start = 0x00000010,
+ Stop = 0x00000020,
+ PauseContinue = 0x00000040,
+ Interrogate = 0x00000080,
+ UserDefinedControl = 0x00000100,
+ Delete = 0x00010000,
+ ReadControl = 0x00020000,
+ WriteDac = 0x00040000,
+ WriteOwner = 0x00080000,
+ AllAccess = 0x000F01FF,
+ AccessSystemSecurity = 0x01000000,
+ }
+
+ public enum ServiceStartType : uint
+ {
+ BootStart = 0x00000000,
+ SystemStart = 0x00000001,
+ AutoStart = 0x00000002,
+ DemandStart = 0x00000003,
+ Disabled = 0x00000004,
+
+ // Not part of the ChangeServiceConfig enumeration but built by the Service class for the StartType property.
+ AutoStartDelayed = 0x1000000
+ }
+
+ [Flags]
+ public enum ServiceType : uint
+ {
+ KernelDriver = 0x00000001,
+ FileSystemDriver = 0x00000002,
+ Adapter = 0x00000004,
+ RecognizerDriver = 0x00000008,
+ Driver = KernelDriver | FileSystemDriver | RecognizerDriver,
+ Win32OwnProcess = 0x00000010,
+ Win32ShareProcess = 0x00000020,
+ Win32 = Win32OwnProcess | Win32ShareProcess,
+ UserProcess = 0x00000040,
+ UserOwnProcess = Win32OwnProcess | UserProcess,
+ UserShareProcess = Win32ShareProcess | UserProcess,
+ UserServiceInstance = 0x00000080,
+ InteractiveProcess = 0x00000100,
+ PkgService = 0x00000200,
+ }
+
+ public enum ServiceSidInfo : uint
+ {
+ None,
+ Unrestricted,
+ Restricted = 3,
+ }
+
+ public enum ServiceStatus : uint
+ {
+ Stopped = 0x00000001,
+ StartPending = 0x00000002,
+ StopPending = 0x00000003,
+ Running = 0x00000004,
+ ContinuePending = 0x00000005,
+ PausePending = 0x00000006,
+ Paused = 0x00000007,
+ }
+
+ public enum TriggerAction : uint
+ {
+ ServiceStart = 0x00000001,
+ ServiceStop = 0x00000002,
+ }
+
+ public enum TriggerDataType : uint
+ {
+ Binary = 0x00000001,
+ String = 0x00000002,
+ Level = 0x00000003,
+ KeywordAny = 0x00000004,
+ KeywordAll = 0x00000005,
+ }
+
+ public enum TriggerType : uint
+ {
+ DeviceInterfaceArrival = 0x00000001,
+ IpAddressAvailability = 0x00000002,
+ DomainJoin = 0x00000003,
+ FirewallPortEvent = 0x00000004,
+ GroupPolicy = 0x00000005,
+ NetworkEndpoint = 0x00000006,
+ Custom = 0x00000014,
+ }
+
+ public class ServiceManagerException : System.ComponentModel.Win32Exception
+ {
+ private string _msg;
+
+ public ServiceManagerException(string message) : this(Marshal.GetLastWin32Error(), message) { }
+ public ServiceManagerException(int errorCode, string message) : base(errorCode)
+ {
+ _msg = String.Format("{0} ({1}, Win32ErrorCode {2} - 0x{2:X8})", message, base.Message, errorCode);
+ }
+
+ public override string Message { get { return _msg; } }
+ public static explicit operator ServiceManagerException(string message)
+ {
+ return new ServiceManagerException(message);
+ }
+ }
+
+ public class Action
+ {
+ public FailureAction Type;
+ public UInt32 Delay;
+ }
+
+ public class FailureActions
+ {
+ public UInt32? ResetPeriod = null; // Get is always populated, can be null on set to preserve existing.
+ public string RebootMsg = null;
+ public string Command = null;
+ public List<Action> Actions = null;
+
+ public FailureActions() { }
+
+ internal FailureActions(NativeHelpers.SERVICE_FAILURE_ACTIONSW actions)
+ {
+ ResetPeriod = actions.dwResetPeriod;
+ RebootMsg = actions.lpRebootMsg;
+ Command = actions.lpCommand;
+ Actions = new List<Action>();
+
+ int actionLength = Marshal.SizeOf(typeof(NativeHelpers.SC_ACTION));
+ for (int i = 0; i < actions.cActions; i++)
+ {
+ IntPtr actionPtr = IntPtr.Add(actions.lpsaActions, i * actionLength);
+
+ NativeHelpers.SC_ACTION rawAction = (NativeHelpers.SC_ACTION)Marshal.PtrToStructure(
+ actionPtr, typeof(NativeHelpers.SC_ACTION));
+
+ Actions.Add(new Action()
+ {
+ Type = rawAction.Type,
+ Delay = rawAction.Delay,
+ });
+ }
+ }
+ }
+
+ public class TriggerItem
+ {
+ public TriggerDataType Type;
+ public object Data; // Can be string, List<string>, byte, byte[], or UInt64 depending on Type.
+
+ public TriggerItem() { }
+
+ internal TriggerItem(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM dataItem)
+ {
+ Type = dataItem.dwDataType;
+
+ byte[] itemBytes = new byte[dataItem.cbData];
+ Marshal.Copy(dataItem.pData, itemBytes, 0, itemBytes.Length);
+
+ switch (dataItem.dwDataType)
+ {
+ case TriggerDataType.String:
+ string value = Encoding.Unicode.GetString(itemBytes, 0, itemBytes.Length);
+
+ if (value.EndsWith("\0\0"))
+ {
+ // Multistring with a delimiter of \0 and terminated with \0\0.
+ Data = new List<string>(value.Split(new char[1] { '\0' }, StringSplitOptions.RemoveEmptyEntries));
+ }
+ else
+ // Just a single string with null character at the end, strip it off.
+ Data = value.Substring(0, value.Length - 1);
+ break;
+ case TriggerDataType.Level:
+ Data = itemBytes[0];
+ break;
+ case TriggerDataType.KeywordAll:
+ case TriggerDataType.KeywordAny:
+ Data = BitConverter.ToUInt64(itemBytes, 0);
+ break;
+ default:
+ Data = itemBytes;
+ break;
+ }
+ }
+ }
+
+ public class Trigger
+ {
+ // https://docs.microsoft.com/en-us/windows/win32/api/winsvc/ns-winsvc-service_trigger
+ public const string NAMED_PIPE_EVENT_GUID = "1f81d131-3fac-4537-9e0c-7e7b0c2f4b55";
+ public const string RPC_INTERFACE_EVENT_GUID = "bc90d167-9470-4139-a9ba-be0bbbf5b74d";
+ public const string DOMAIN_JOIN_GUID = "1ce20aba-9851-4421-9430-1ddeb766e809";
+ public const string DOMAIN_LEAVE_GUID = "ddaf516e-58c2-4866-9574-c3b615d42ea1";
+ public const string FIREWALL_PORT_OPEN_GUID = "b7569e07-8421-4ee0-ad10-86915afdad09";
+ public const string FIREWALL_PORT_CLOSE_GUID = "a144ed38-8e12-4de4-9d96-e64740b1a524";
+ public const string MACHINE_POLICY_PRESENT_GUID = "659fcae6-5bdb-4da9-b1ff-ca2a178d46e0";
+ public const string NETWORK_MANAGER_FIRST_IP_ADDRESS_ARRIVAL_GUID = "4f27f2de-14e2-430b-a549-7cd48cbc8245";
+ public const string NETWORK_MANAGER_LAST_IP_ADDRESS_REMOVAL_GUID = "cc4ba62a-162e-4648-847a-b6bdf993e335";
+ public const string USER_POLICY_PRESENT_GUID = "54fb46c8-f089-464c-b1fd-59d1b62c3b50";
+
+ public TriggerType Type;
+ public TriggerAction Action;
+ public Guid SubType;
+ public List<TriggerItem> DataItems = new List<TriggerItem>();
+
+ public Trigger() { }
+
+ internal Trigger(NativeHelpers.SERVICE_TRIGGER trigger)
+ {
+ Type = trigger.dwTriggerType;
+ Action = trigger.dwAction;
+ SubType = (Guid)Marshal.PtrToStructure(trigger.pTriggerSubtype, typeof(Guid));
+
+ int dataItemLength = Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM));
+ for (int i = 0; i < trigger.cDataItems; i++)
+ {
+ IntPtr dataPtr = IntPtr.Add(trigger.pDataItems, i * dataItemLength);
+
+ var dataItem = (NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)Marshal.PtrToStructure(
+ dataPtr, typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM));
+
+ DataItems.Add(new TriggerItem(dataItem));
+ }
+ }
+ }
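+
+ // Illustrative only: a firewall port trigger read back through the class
+ // above would surface as Type = TriggerType.FirewallPortEvent with SubType
+ // set to FIREWALL_PORT_OPEN_GUID and its port/protocol details exposed as
+ // TriggerItem entries (the exact data layout is an assumption here).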
+
+ public class Service : IDisposable
+ {
+ private const UInt32 SERVICE_NO_CHANGE = 0xFFFFFFFF;
+
+ private SafeServiceHandle _scmHandle;
+ private SafeServiceHandle _serviceHandle;
+ private SafeMemoryBuffer _rawServiceConfig;
+ private NativeHelpers.SERVICE_STATUS_PROCESS _statusProcess;
+
+ private NativeHelpers.QUERY_SERVICE_CONFIGW _ServiceConfig
+ {
+ get
+ {
+ return (NativeHelpers.QUERY_SERVICE_CONFIGW)Marshal.PtrToStructure(
+ _rawServiceConfig.DangerousGetHandle(), typeof(NativeHelpers.QUERY_SERVICE_CONFIGW));
+ }
+ }
+
+ // ServiceConfig
+ public string ServiceName { get; private set; }
+
+ public ServiceType ServiceType
+ {
+ get { return _ServiceConfig.dwServiceType; }
+ set { ChangeServiceConfig(serviceType: value); }
+ }
+
+ public ServiceStartType StartType
+ {
+ get
+ {
+ ServiceStartType startType = _ServiceConfig.dwStartType;
+ if (startType == ServiceStartType.AutoStart)
+ {
+ var value = QueryServiceConfig2<NativeHelpers.SERVICE_DELAYED_AUTO_START_INFO>(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DELAYED_AUTO_START_INFO);
+
+ if (value.fDelayedAutostart)
+ startType = ServiceStartType.AutoStartDelayed;
+ }
+
+ return startType;
+ }
+ set
+ {
+ ServiceStartType newStartType = value;
+ bool delayedStart = false;
+ if (value == ServiceStartType.AutoStartDelayed)
+ {
+ newStartType = ServiceStartType.AutoStart;
+ delayedStart = true;
+ }
+
+ ChangeServiceConfig(startType: newStartType);
+
+ var info = new NativeHelpers.SERVICE_DELAYED_AUTO_START_INFO()
+ {
+ fDelayedAutostart = delayedStart,
+ };
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DELAYED_AUTO_START_INFO, info);
+ }
+ }
+
+ public ErrorControl ErrorControl
+ {
+ get { return _ServiceConfig.dwErrorControl; }
+ set { ChangeServiceConfig(errorControl: value); }
+ }
+
+ public string Path
+ {
+ get { return _ServiceConfig.lpBinaryPathName; }
+ set { ChangeServiceConfig(binaryPath: value); }
+ }
+
+ public string LoadOrderGroup
+ {
+ get { return _ServiceConfig.lpLoadOrderGroup; }
+ set { ChangeServiceConfig(loadOrderGroup: value); }
+ }
+
+ public List<string> DependentOn
+ {
+ get
+ {
+ StringBuilder deps = new StringBuilder();
+ IntPtr depPtr = _ServiceConfig.lpDependencies;
+
+ bool wasNull = false;
+ while (true)
+ {
+ // Get the current char at the pointer and add it to the StringBuilder.
+ byte[] charBytes = new byte[sizeof(char)];
+ Marshal.Copy(depPtr, charBytes, 0, charBytes.Length);
+ depPtr = IntPtr.Add(depPtr, charBytes.Length);
+ char currentChar = BitConverter.ToChar(charBytes, 0);
+ deps.Append(currentChar);
+
+ // If the previous and current char is \0 exit the loop.
+ if (currentChar == '\0' && wasNull)
+ break;
+ wasNull = currentChar == '\0';
+ }
+
+ return new List<string>(deps.ToString().Split(new char[1] { '\0' },
+ StringSplitOptions.RemoveEmptyEntries));
+ }
+ set { ChangeServiceConfig(dependencies: value); }
+ }
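+
+ // Example of the multistring layout parsed by the getter above (sample
+ // values are assumptions): a buffer containing "RpcSs\0http\0\0" yields
+ // the list ["RpcSs", "http"].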
+
+ public IdentityReference Account
+ {
+ get
+ {
+ if (_ServiceConfig.lpServiceStartName == null)
+ // User services don't have the start name specified and will be null.
+ return null;
+ else if (_ServiceConfig.lpServiceStartName == "LocalSystem")
+ // Special string used for the SYSTEM account, this is the same even for different localisations.
+ return (NTAccount)new SecurityIdentifier("S-1-5-18").Translate(typeof(NTAccount));
+ else
+ return new NTAccount(_ServiceConfig.lpServiceStartName);
+ }
+ set
+ {
+ string startName = null;
+ string pass = null;
+
+ if (value != null)
+ {
+ // Create a SID and convert back from a SID to get the Netlogon form regardless of the input
+ // specified.
+ SecurityIdentifier accountSid = (SecurityIdentifier)value.Translate(typeof(SecurityIdentifier));
+ NTAccount accountName = (NTAccount)accountSid.Translate(typeof(NTAccount));
+ string[] accountSplit = accountName.Value.Split(new char[1] { '\\' }, 2);
+
+ // SYSTEM, Local Service, Network Service
+ List<string> serviceAccounts = new List<string> { "S-1-5-18", "S-1-5-19", "S-1-5-20" };
+
+ // Well known service accounts and MSAs should have no password set. Explicitly blank out the
+ // existing password to ensure older passwords are no longer stored by Windows.
+ if (serviceAccounts.Contains(accountSid.Value) || accountSplit[1].EndsWith("$"))
+ pass = "";
+
+ // The SYSTEM account uses this special string to specify that account otherwise use the original
+ // NTAccount value in case it is in a custom format (not Netlogon) for a reason.
+ if (accountSid.Value == serviceAccounts[0])
+ startName = "LocalSystem";
+ else
+ startName = value.Translate(typeof(NTAccount)).Value;
+ }
+
+ ChangeServiceConfig(startName: startName, password: pass);
+ }
+ }
+
+ public string Password { set { ChangeServiceConfig(password: value); } }
+
+ public string DisplayName
+ {
+ get { return _ServiceConfig.lpDisplayName; }
+ set { ChangeServiceConfig(displayName: value); }
+ }
+
+ // ServiceConfig2
+
+ public string Description
+ {
+ get
+ {
+ var value = QueryServiceConfig2<NativeHelpers.SERVICE_DESCRIPTIONW>(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DESCRIPTION);
+
+ return value.lpDescription;
+ }
+ set
+ {
+ var info = new NativeHelpers.SERVICE_DESCRIPTIONW()
+ {
+ lpDescription = value,
+ };
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DESCRIPTION, info);
+ }
+ }
+
+ public FailureActions FailureActions
+ {
+ get
+ {
+ using (SafeMemoryBuffer b = QueryServiceConfig2(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS))
+ {
+ NativeHelpers.SERVICE_FAILURE_ACTIONSW value = (NativeHelpers.SERVICE_FAILURE_ACTIONSW)
+ Marshal.PtrToStructure(b.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_FAILURE_ACTIONSW));
+
+ return new FailureActions(value);
+ }
+ }
+ set
+ {
+ // dwResetPeriod and lpsaActions must be set together, so we need to read the existing config if
+ // someone wants to update one of them but the other isn't explicitly defined.
+ UInt32? resetPeriod = value.ResetPeriod;
+ List<Action> actions = value.Actions;
+ if ((resetPeriod != null && actions == null) || (resetPeriod == null && actions != null))
+ {
+ FailureActions existingValue = this.FailureActions;
+
+ if (resetPeriod != null && existingValue.Actions.Count == 0)
+ throw new ArgumentException(
+ "Cannot set FailureActions ResetPeriod when no Actions are specified and none already exist");
+ else if (resetPeriod == null)
+ resetPeriod = (UInt32)existingValue.ResetPeriod;
+
+ if (actions == null)
+ actions = existingValue.Actions;
+ }
+
+ var info = new NativeHelpers.SERVICE_FAILURE_ACTIONSW()
+ {
+ dwResetPeriod = resetPeriod == null ? 0 : (UInt32)resetPeriod,
+ lpRebootMsg = value.RebootMsg,
+ lpCommand = value.Command,
+ cActions = actions == null ? 0 : (UInt32)actions.Count,
+ lpsaActions = IntPtr.Zero,
+ };
+
+ // null means to keep the existing actions whereas an empty list deletes the actions.
+ if (actions == null)
+ {
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS, info);
+ return;
+ }
+
+ int actionLength = Marshal.SizeOf(typeof(NativeHelpers.SC_ACTION));
+ using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(actionLength * actions.Count))
+ {
+ info.lpsaActions = buffer.DangerousGetHandle();
+ HashSet<string> privileges = new HashSet<string>();
+
+ for (int i = 0; i < actions.Count; i++)
+ {
+ IntPtr actionPtr = IntPtr.Add(info.lpsaActions, i * actionLength);
+ NativeHelpers.SC_ACTION action = new NativeHelpers.SC_ACTION()
+ {
+ Delay = actions[i].Delay,
+ Type = actions[i].Type,
+ };
+ Marshal.StructureToPtr(action, actionPtr, false);
+
+ // Need to make sure the SeShutdownPrivilege is enabled when adding a reboot failure action.
+ if (action.Type == FailureAction.Reboot)
+ privileges.Add("SeShutdownPrivilege");
+ }
+
+ using (new PrivilegeEnabler(true, privileges.ToList().ToArray()))
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS, info);
+ }
+ }
+ }
+
+ public bool FailureActionsOnNonCrashFailures
+ {
+ get
+ {
+ var value = QueryServiceConfig2<NativeHelpers.SERVICE_FAILURE_ACTIONS_FLAG>(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG);
+
+ return value.fFailureActionsOnNonCrashFailures;
+ }
+ set
+ {
+ var info = new NativeHelpers.SERVICE_FAILURE_ACTIONS_FLAG()
+ {
+ fFailureActionsOnNonCrashFailures = value,
+ };
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG, info);
+ }
+ }
+
+ public ServiceSidInfo ServiceSidInfo
+ {
+ get
+ {
+ var value = QueryServiceConfig2<NativeHelpers.SERVICE_SID_INFO>(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_SERVICE_SID_INFO);
+
+ return value.dwServiceSidType;
+ }
+ set
+ {
+ var info = new NativeHelpers.SERVICE_SID_INFO()
+ {
+ dwServiceSidType = value,
+ };
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_SERVICE_SID_INFO, info);
+ }
+ }
+
+ public List<string> RequiredPrivileges
+ {
+ get
+ {
+ using (SafeMemoryBuffer buffer = QueryServiceConfig2(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO))
+ {
+ var value = (NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW)Marshal.PtrToStructure(
+ buffer.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW));
+
+ int structLength = Marshal.SizeOf(value);
+ int stringLength = ((int)buffer.BufferLength - structLength) / sizeof(char);
+
+ if (stringLength > 0)
+ {
+ string privilegesString = Marshal.PtrToStringUni(value.pmszRequiredPrivileges, stringLength);
+ return new List<string>(privilegesString.Split(new char[1] { '\0' },
+ StringSplitOptions.RemoveEmptyEntries));
+ }
+ else
+ return new List<string>();
+ }
+ }
+ set
+ {
+ string privilegeString = String.Join("\0", value ?? new List<string>()) + "\0\0";
+
+ using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(Marshal.StringToHGlobalUni(privilegeString)))
+ {
+ var info = new NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW()
+ {
+ pmszRequiredPrivileges = buffer.DangerousGetHandle(),
+ };
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO, info);
+ }
+ }
+ }
+
+ public UInt32 PreShutdownTimeout
+ {
+ get
+ {
+ var value = QueryServiceConfig2<NativeHelpers.SERVICE_PRESHUTDOWN_INFO>(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PRESHUTDOWN_INFO);
+
+ return value.dwPreshutdownTimeout;
+ }
+ set
+ {
+ var info = new NativeHelpers.SERVICE_PRESHUTDOWN_INFO()
+ {
+ dwPreshutdownTimeout = value,
+ };
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PRESHUTDOWN_INFO, info);
+ }
+ }
+
+ public List<Trigger> Triggers
+ {
+ get
+ {
+ List<Trigger> triggers = new List<Trigger>();
+
+ using (SafeMemoryBuffer b = QueryServiceConfig2(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO))
+ {
+ var value = (NativeHelpers.SERVICE_TRIGGER_INFO)Marshal.PtrToStructure(
+ b.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_TRIGGER_INFO));
+
+ int triggerLength = Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER));
+ for (int i = 0; i < value.cTriggers; i++)
+ {
+ IntPtr triggerPtr = IntPtr.Add(value.pTriggers, i * triggerLength);
+ var trigger = (NativeHelpers.SERVICE_TRIGGER)Marshal.PtrToStructure(triggerPtr,
+ typeof(NativeHelpers.SERVICE_TRIGGER));
+
+ triggers.Add(new Trigger(trigger));
+ }
+ }
+
+ return triggers;
+ }
+ set
+ {
+ var info = new NativeHelpers.SERVICE_TRIGGER_INFO()
+ {
+ cTriggers = value == null ? 0 : (UInt32)value.Count,
+ pTriggers = IntPtr.Zero,
+ pReserved = IntPtr.Zero,
+ };
+
+ if (info.cTriggers == 0)
+ {
+ try
+ {
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO, info);
+ }
+ catch (ServiceManagerException e)
+ {
+ // Can fail with ERROR_INVALID_PARAMETER if no triggers were already set on the service, just
+ // continue as the service is what we want it to be.
+ if (e.NativeErrorCode != 87)
+ throw;
+ }
+ return;
+ }
+
+ // Due to the dynamic nature of the trigger structure(s) we need to manually calculate the size of the
+ // data items on each trigger if present. This also serializes the raw data items to bytes here.
+ int structDataLength = 0;
+ int dataLength = 0;
+ Queue<byte[]> dataItems = new Queue<byte[]>();
+ foreach (Trigger trigger in value)
+ {
+ if (trigger.DataItems == null || trigger.DataItems.Count == 0)
+ continue;
+
+ foreach (TriggerItem dataItem in trigger.DataItems)
+ {
+ structDataLength += Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM));
+
+ byte[] dataItemBytes;
+ Type dataItemType = dataItem.Data.GetType();
+ if (dataItemType == typeof(byte))
+ dataItemBytes = new byte[1] { (byte)dataItem.Data };
+ else if (dataItemType == typeof(byte[]))
+ dataItemBytes = (byte[])dataItem.Data;
+ else if (dataItemType == typeof(UInt64))
+ dataItemBytes = BitConverter.GetBytes((UInt64)dataItem.Data);
+ else if (dataItemType == typeof(string))
+ dataItemBytes = Encoding.Unicode.GetBytes((string)dataItem.Data + "\0");
+ else if (dataItemType == typeof(List<string>))
+ dataItemBytes = Encoding.Unicode.GetBytes(
+ String.Join("\0", (List<string>)dataItem.Data) + "\0");
+ else
+ throw new ArgumentException(String.Format("Trigger data type '{0}' is not a supported type",
+ dataItemType.Name));
+
+ dataLength += dataItemBytes.Length;
+ dataItems.Enqueue(dataItemBytes);
+ }
+ }
+
+ using (SafeMemoryBuffer triggerBuffer = new SafeMemoryBuffer(
+ value.Count * Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER))))
+ using (SafeMemoryBuffer triggerGuidBuffer = new SafeMemoryBuffer(
+ value.Count * Marshal.SizeOf(typeof(Guid))))
+ using (SafeMemoryBuffer dataItemBuffer = new SafeMemoryBuffer(structDataLength))
+ using (SafeMemoryBuffer dataBuffer = new SafeMemoryBuffer(dataLength))
+ {
+ info.pTriggers = triggerBuffer.DangerousGetHandle();
+
+ IntPtr triggerPtr = triggerBuffer.DangerousGetHandle();
+ IntPtr guidPtr = triggerGuidBuffer.DangerousGetHandle();
+ IntPtr dataItemPtr = dataItemBuffer.DangerousGetHandle();
+ IntPtr dataPtr = dataBuffer.DangerousGetHandle();
+
+ foreach (Trigger trigger in value)
+ {
+ int dataCount = trigger.DataItems == null ? 0 : trigger.DataItems.Count;
+ var rawTrigger = new NativeHelpers.SERVICE_TRIGGER()
+ {
+ dwTriggerType = trigger.Type,
+ dwAction = trigger.Action,
+ pTriggerSubtype = guidPtr,
+ cDataItems = (UInt32)dataCount,
+ pDataItems = dataCount == 0 ? IntPtr.Zero : dataItemPtr,
+ };
+ guidPtr = StructureToPtr(trigger.SubType, guidPtr);
+
+ for (int i = 0; i < rawTrigger.cDataItems; i++)
+ {
+ byte[] dataItemBytes = dataItems.Dequeue();
+ var rawTriggerData = new NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM()
+ {
+ dwDataType = trigger.DataItems[i].Type,
+ cbData = (UInt32)dataItemBytes.Length,
+ pData = dataPtr,
+ };
+ Marshal.Copy(dataItemBytes, 0, dataPtr, dataItemBytes.Length);
+ dataPtr = IntPtr.Add(dataPtr, dataItemBytes.Length);
+
+ dataItemPtr = StructureToPtr(rawTriggerData, dataItemPtr);
+ }
+
+ triggerPtr = StructureToPtr(rawTrigger, triggerPtr);
+ }
+
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO, info);
+ }
+ }
+ }
+
+ public UInt16? PreferredNode
+ {
+ get
+ {
+ try
+ {
+ var value = QueryServiceConfig2<NativeHelpers.SERVICE_PREFERRED_NODE_INFO>(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PREFERRED_NODE);
+
+ return value.usPreferredNode;
+ }
+ catch (ServiceManagerException e)
+ {
+ // If host has no NUMA support this will fail with ERROR_INVALID_PARAMETER
+ if (e.NativeErrorCode == 0x00000057) // ERROR_INVALID_PARAMETER
+ return null;
+
+ throw;
+ }
+ }
+ set
+ {
+ var info = new NativeHelpers.SERVICE_PREFERRED_NODE_INFO();
+ if (value == null)
+ info.fDelete = true;
+ else
+ info.usPreferredNode = (UInt16)value;
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PREFERRED_NODE, info);
+ }
+ }
+
+ public LaunchProtection LaunchProtection
+ {
+ get
+ {
+ var value = QueryServiceConfig2<NativeHelpers.SERVICE_LAUNCH_PROTECTED_INFO>(
+ NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_LAUNCH_PROTECTED);
+
+ return value.dwLaunchProtected;
+ }
+ set
+ {
+ var info = new NativeHelpers.SERVICE_LAUNCH_PROTECTED_INFO()
+ {
+ dwLaunchProtected = value,
+ };
+ ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_LAUNCH_PROTECTED, info);
+ }
+ }
+
+ // ServiceStatus
+ public ServiceStatus State { get { return _statusProcess.dwCurrentState; } }
+
+ public ControlsAccepted ControlsAccepted { get { return _statusProcess.dwControlsAccepted; } }
+
+ public UInt32 Win32ExitCode { get { return _statusProcess.dwWin32ExitCode; } }
+
+ public UInt32 ServiceExitCode { get { return _statusProcess.dwServiceSpecificExitCode; } }
+
+ public UInt32 Checkpoint { get { return _statusProcess.dwCheckPoint; } }
+
+ public UInt32 WaitHint { get { return _statusProcess.dwWaitHint; } }
+
+ public UInt32 ProcessId { get { return _statusProcess.dwProcessId; } }
+
+ public ServiceFlags ServiceFlags { get { return _statusProcess.dwServiceFlags; } }
+
+ public Service(string name) : this(name, ServiceRights.AllAccess) { }
+
+ public Service(string name, ServiceRights access) : this(name, access, SCMRights.Connect) { }
+
+ public Service(string name, ServiceRights access, SCMRights scmAccess)
+ {
+ ServiceName = name;
+ _scmHandle = OpenSCManager(scmAccess);
+ _serviceHandle = NativeMethods.OpenServiceW(_scmHandle, name, access);
+ if (_serviceHandle.IsInvalid)
+ throw new ServiceManagerException(String.Format("Failed to open service '{0}'", name));
+
+ Refresh();
+ }
+
+ private Service(SafeServiceHandle scmHandle, SafeServiceHandle serviceHandle, string name)
+ {
+ ServiceName = name;
+ _scmHandle = scmHandle;
+ _serviceHandle = serviceHandle;
+
+ Refresh();
+ }
+
+ // EnumDependentServices
+ public List<string> DependedBy
+ {
+ get
+ {
+ UInt32 bytesNeeded = 0;
+ UInt32 numServices = 0;
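+ // 3 == SERVICE_STATE_ALL (SERVICE_ACTIVE | SERVICE_INACTIVE).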
+ NativeMethods.EnumDependentServicesW(_serviceHandle, 3, new SafeMemoryBuffer(IntPtr.Zero), 0,
+ out bytesNeeded, out numServices);
+
+ using (SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded))
+ {
+ if (!NativeMethods.EnumDependentServicesW(_serviceHandle, 3, buffer, bytesNeeded, out bytesNeeded,
+ out numServices))
+ {
+ throw new ServiceManagerException("Failed to enumerated dependent services");
+ }
+
+ List<string> dependents = new List<string>();
+ Type enumType = typeof(NativeHelpers.ENUM_SERVICE_STATUSW);
+ for (int i = 0; i < numServices; i++)
+ {
+ var service = (NativeHelpers.ENUM_SERVICE_STATUSW)Marshal.PtrToStructure(
+ IntPtr.Add(buffer.DangerousGetHandle(), i * Marshal.SizeOf(enumType)), enumType);
+
+ dependents.Add(service.lpServiceName);
+ }
+
+ return dependents;
+ }
+ }
+ }
+
+ public static Service Create(string name, string binaryPath, string displayName = null,
+ ServiceType serviceType = ServiceType.Win32OwnProcess,
+ ServiceStartType startType = ServiceStartType.DemandStart, ErrorControl errorControl = ErrorControl.Normal,
+ string loadOrderGroup = null, List<string> dependencies = null, string startName = null,
+ string password = null)
+ {
+ SafeServiceHandle scmHandle = OpenSCManager(SCMRights.CreateService | SCMRights.Connect);
+
+ if (displayName == null)
+ displayName = name;
+
+ string depString = null;
+ if (dependencies != null && dependencies.Count > 0)
+ depString = String.Join("\0", dependencies) + "\0\0";
+
+ SafeServiceHandle serviceHandle = NativeMethods.CreateServiceW(scmHandle, name, displayName,
+ ServiceRights.AllAccess, serviceType, startType, errorControl, binaryPath,
+ loadOrderGroup, IntPtr.Zero, depString, startName, password);
+
+ if (serviceHandle.IsInvalid)
+ throw new ServiceManagerException(String.Format("Failed to create new service '{0}'", name));
+
+ return new Service(scmHandle, serviceHandle, name);
+ }
+
+ public void Delete()
+ {
+ if (!NativeMethods.DeleteService(_serviceHandle))
+ throw new ServiceManagerException("Failed to delete service");
+ Dispose();
+ }
+
+ public void Dispose()
+ {
+ if (_serviceHandle != null)
+ _serviceHandle.Dispose();
+
+ if (_scmHandle != null)
+ _scmHandle.Dispose();
+ GC.SuppressFinalize(this);
+ }
+
+ public void Refresh()
+ {
+ UInt32 bytesNeeded;
+ NativeMethods.QueryServiceConfigW(_serviceHandle, IntPtr.Zero, 0, out bytesNeeded);
+
+ _rawServiceConfig = new SafeMemoryBuffer((int)bytesNeeded);
+ if (!NativeMethods.QueryServiceConfigW(_serviceHandle, _rawServiceConfig.DangerousGetHandle(), bytesNeeded,
+ out bytesNeeded))
+ {
+ throw new ServiceManagerException("Failed to query service config");
+ }
+
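+ // InfoLevel 0 == SC_STATUS_PROCESS_INFO, the only defined level.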
+ NativeMethods.QueryServiceStatusEx(_serviceHandle, 0, IntPtr.Zero, 0, out bytesNeeded);
+ using (SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded))
+ {
+ if (!NativeMethods.QueryServiceStatusEx(_serviceHandle, 0, buffer.DangerousGetHandle(), bytesNeeded,
+ out bytesNeeded))
+ {
+ throw new ServiceManagerException("Failed to query service status");
+ }
+
+ _statusProcess = (NativeHelpers.SERVICE_STATUS_PROCESS)Marshal.PtrToStructure(
+ buffer.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_STATUS_PROCESS));
+ }
+ }
+
+ private void ChangeServiceConfig(ServiceType serviceType = (ServiceType)SERVICE_NO_CHANGE,
+ ServiceStartType startType = (ServiceStartType)SERVICE_NO_CHANGE,
+ ErrorControl errorControl = (ErrorControl)SERVICE_NO_CHANGE, string binaryPath = null,
+ string loadOrderGroup = null, List<string> dependencies = null, string startName = null,
+ string password = null, string displayName = null)
+ {
+ string depString = null;
+ if (dependencies != null && dependencies.Count > 0)
+ depString = String.Join("\0", dependencies) + "\0\0";
+
+ if (!NativeMethods.ChangeServiceConfigW(_serviceHandle, serviceType, startType, errorControl, binaryPath,
+ loadOrderGroup, IntPtr.Zero, depString, startName, password, displayName))
+ {
+ throw new ServiceManagerException("Failed to change service config");
+ }
+
+ Refresh();
+ }
+
+ private void ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel infoLevel, object info)
+ {
+ using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(Marshal.SizeOf(info)))
+ {
+ Marshal.StructureToPtr(info, buffer.DangerousGetHandle(), false);
+
+ if (!NativeMethods.ChangeServiceConfig2W(_serviceHandle, infoLevel, buffer.DangerousGetHandle()))
+ throw new ServiceManagerException("Failed to change service config");
+ }
+ }
+
+ private static SafeServiceHandle OpenSCManager(SCMRights desiredAccess)
+ {
+ SafeServiceHandle handle = NativeMethods.OpenSCManagerW(null, null, desiredAccess);
+ if (handle.IsInvalid)
+ throw new ServiceManagerException("Failed to open SCManager");
+
+ return handle;
+ }
+
+ private T QueryServiceConfig2<T>(NativeHelpers.ConfigInfoLevel infoLevel)
+ {
+ using (SafeMemoryBuffer buffer = QueryServiceConfig2(infoLevel))
+ return (T)Marshal.PtrToStructure(buffer.DangerousGetHandle(), typeof(T));
+ }
+
+ private SafeMemoryBuffer QueryServiceConfig2(NativeHelpers.ConfigInfoLevel infoLevel)
+ {
+ UInt32 bytesNeeded = 0;
+ NativeMethods.QueryServiceConfig2W(_serviceHandle, infoLevel, IntPtr.Zero, 0, out bytesNeeded);
+
+ SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded);
+ if (!NativeMethods.QueryServiceConfig2W(_serviceHandle, infoLevel, buffer.DangerousGetHandle(), bytesNeeded,
+ out bytesNeeded))
+ {
+ throw new ServiceManagerException(String.Format("QueryServiceConfig2W({0}) failed",
+ infoLevel.ToString()));
+ }
+
+ return buffer;
+ }
+
+ private static IntPtr StructureToPtr(object structure, IntPtr ptr)
+ {
+ Marshal.StructureToPtr(structure, ptr, false);
+ return IntPtr.Add(ptr, Marshal.SizeOf(structure));
+ }
+
+ ~Service() { Dispose(); }
+ }
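+
+ // Minimal consumption sketch (illustrative only, not used by any module):
+ // opens an existing service through the wrapper above and reads a few of
+ // the exposed properties. "Spooler" is an assumed example service name.
+ internal static class ServiceUsageExample
+ {
+ public static void Describe()
+ {
+ using (Service service = new Service("Spooler", ServiceRights.QueryConfig | ServiceRights.QueryStatus))
+ {
+ Console.WriteLine("path: {0}", service.Path);
+ Console.WriteLine("start type: {0}", service.StartType);
+ Console.WriteLine("state: {0}", service.State);
+ }
+ }
+ }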
+}
diff --git a/test/support/windows-integration/plugins/modules/async_status.ps1 b/test/support/windows-integration/plugins/modules/async_status.ps1
new file mode 100644
index 00000000..1ce3ff40
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/async_status.ps1
@@ -0,0 +1,58 @@
+#!powershell
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+
+$results = @{changed=$false}
+
+$parsed_args = Parse-Args $args
+$jid = Get-AnsibleParam $parsed_args "jid" -failifempty $true -resultobj $results
+$mode = Get-AnsibleParam $parsed_args "mode" -Default "status" -ValidateSet "status","cleanup"
+
+# parsed in from the async_status action plugin
+$async_dir = Get-AnsibleParam $parsed_args "_async_dir" -type "path" -failifempty $true
+
+$log_path = [System.IO.Path]::Combine($async_dir, $jid)
+
+If(-not $(Test-Path $log_path))
+{
+ Fail-Json @{ansible_job_id=$jid; started=1; finished=1} "could not find job at '$async_dir'"
+}
+
+If($mode -eq "cleanup") {
+ Remove-Item $log_path -Recurse
+ Exit-Json @{ansible_job_id=$jid; erased=$log_path}
+}
+
+# NOT in cleanup mode, assume regular status mode
+# no remote kill mode currently exists, but probably should
+# consider log_path + ".pid" file and also unlink that above
+
+$data = $null
+Try {
+ $data_raw = Get-Content $log_path
+
+ # TODO: move this into module_utils/powershell.ps1?
+ $jss = New-Object System.Web.Script.Serialization.JavaScriptSerializer
+ $data = $jss.DeserializeObject($data_raw)
+}
+Catch {
+ If(-not $data_raw) {
+ # file not written yet? That means it is running
+ Exit-Json @{results_file=$log_path; ansible_job_id=$jid; started=1; finished=0}
+ }
+ Else {
+ Fail-Json @{ansible_job_id=$jid; results_file=$log_path; started=1; finished=1} "Could not parse job output: $data_raw"
+ }
+}
+
+If (-not $data.ContainsKey("started")) {
+ $data['finished'] = 1
+ $data['ansible_job_id'] = $jid
+}
+ElseIf (-not $data.ContainsKey("finished")) {
+ $data['finished'] = 0
+}
+
+Exit-Json $data
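+
+# Illustrative only: the async_status action plugin supplies the arguments,
+# roughly equivalent to the following (values are assumptions for example):
+#   async_status: jid=123456789.12345 mode=status
+# with _async_dir injected from the controller-side async directory setting.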
diff --git a/test/support/windows-integration/plugins/modules/setup.ps1 b/test/support/windows-integration/plugins/modules/setup.ps1
new file mode 100644
index 00000000..50647239
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/setup.ps1
@@ -0,0 +1,516 @@
+#!powershell
+
+# Copyright: (c) 2018, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+
+Function Get-CustomFacts {
+ [cmdletBinding()]
+ param (
+ [Parameter(mandatory=$false)]
+ $factpath = $null
+ )
+
+ if (Test-Path -Path $factpath) {
+ $FactsFiles = Get-ChildItem -Path $factpath | Where-Object -FilterScript {($PSItem.PSIsContainer -eq $false) -and ($PSItem.Extension -eq '.ps1')}
+
+ foreach ($FactsFile in $FactsFiles) {
+ $out = & $($FactsFile.FullName)
+ $result.ansible_facts.Add("ansible_$(($FactsFile.Name).Split('.')[0])", $out)
+ }
+ }
+ else
+ {
+ Add-Warning $result "Non-existent path was set for local facts - $factpath"
+ }
+}
+
+Function Get-MachineSid {
+ # The machine SID is stored in HKLM:\SECURITY\SAM\Domains\Account and is
+ # only accessible by the Local System account. This method gets the local
+ # admin account (SID ends with -500) and strips the trailing RID to get the machine SID.
+
+ $machine_sid = $null
+
+ try {
+ $admins_sid = "S-1-5-32-544"
+ $admin_group = ([Security.Principal.SecurityIdentifier]$admins_sid).Translate([Security.Principal.NTAccount]).Value
+
+ Add-Type -AssemblyName System.DirectoryServices.AccountManagement
+ $principal_context = New-Object -TypeName System.DirectoryServices.AccountManagement.PrincipalContext([System.DirectoryServices.AccountManagement.ContextType]::Machine)
+ $group_principal = New-Object -TypeName System.DirectoryServices.AccountManagement.GroupPrincipal($principal_context, $admin_group)
+ $searcher = New-Object -TypeName System.DirectoryServices.AccountManagement.PrincipalSearcher($group_principal)
+ $groups = $searcher.FindOne()
+
+ foreach ($user in $groups.Members) {
+ $user_sid = $user.Sid
+ if ($user_sid.Value.EndsWith("-500")) {
+ $machine_sid = $user_sid.AccountDomainSid.Value
+ break
+ }
+ }
+ } catch {
+ # Can fail for any number of reasons; if it does, just return the original null
+ Add-Warning -obj $result -message "Error during machine sid retrieval: $($_.Exception.Message)"
+ }
+
+ return $machine_sid
+}
+
+$cim_instances = @{}
+
+Function Get-LazyCimInstance([string]$instance_name, [string]$namespace="Root\CIMV2") {
+ if(-not $cim_instances.ContainsKey($instance_name)) {
+ $cim_instances[$instance_name] = $(Get-CimInstance -Namespace $namespace -ClassName $instance_name)
+ }
+
+ return $cim_instances[$instance_name]
+}
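+
+# Illustrative only: repeated lookups of the same class reuse the cached
+# query above, so both calls below trigger a single Get-CimInstance call:
+#   $os = Get-LazyCimInstance Win32_OperatingSystem
+#   $os_again = Get-LazyCimInstance Win32_OperatingSystem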
+
+$result = @{
+ ansible_facts = @{ }
+ changed = $false
+}
+
+$grouped_subsets = @{
+ min=[System.Collections.Generic.List[string]]@('date_time','distribution','dns','env','local','platform','powershell_version','user')
+ network=[System.Collections.Generic.List[string]]@('all_ipv4_addresses','all_ipv6_addresses','interfaces','windows_domain', 'winrm')
+ hardware=[System.Collections.Generic.List[string]]@('bios','memory','processor','uptime','virtual')
+ external=[System.Collections.Generic.List[string]]@('facter')
+}
+
+# build "all" set from everything mentioned in the group- this means every value must be in at least one subset to be considered legal
+$all_set = [System.Collections.Generic.HashSet[string]]@()
+
+foreach($kv in $grouped_subsets.GetEnumerator()) {
+ [void] $all_set.UnionWith($kv.Value)
+}
+
+# dynamically create an "all" subset now that we know what should be in it
+$grouped_subsets['all'] = [System.Collections.Generic.List[string]]$all_set
+
+# start with all, build up gather and exclude subsets
+$gather_subset = [System.Collections.Generic.HashSet[string]]$grouped_subsets.all
+$explicit_subset = [System.Collections.Generic.HashSet[string]]@()
+$exclude_subset = [System.Collections.Generic.HashSet[string]]@()
+
+$params = Parse-Args $args -supports_check_mode $true
+$factpath = Get-AnsibleParam -obj $params -name "fact_path" -type "path"
+$gather_subset_source = Get-AnsibleParam -obj $params -name "gather_subset" -type "list" -default "all"
+
+foreach($item in $gather_subset_source) {
+ if(([string]$item).StartsWith("!")) {
+ $item = ([string]$item).Substring(1)
+ if($item -eq "all") {
+ $all_minus_min = [System.Collections.Generic.HashSet[string]]@($all_set)
+ [void] $all_minus_min.ExceptWith($grouped_subsets.min)
+ [void] $exclude_subset.UnionWith($all_minus_min)
+ }
+ elseif($grouped_subsets.ContainsKey($item)) {
+ [void] $exclude_subset.UnionWith($grouped_subsets[$item])
+ }
+ elseif($all_set.Contains($item)) {
+ [void] $exclude_subset.Add($item)
+ }
+ # NB: invalid exclude values are ignored, since that's what posix setup does
+ }
+ else {
+ if($grouped_subsets.ContainsKey($item)) {
+ [void] $explicit_subset.UnionWith($grouped_subsets[$item])
+ }
+ elseif($all_set.Contains($item)) {
+ [void] $explicit_subset.Add($item)
+ }
+ else {
+ # NB: POSIX setup fails on invalid value; we warn, because we don't implement the same set as POSIX
+ # and we don't have platform-specific config for this...
+ Add-Warning $result "invalid value $item specified in gather_subset"
+ }
+ }
+}
+
+[void] $gather_subset.ExceptWith($exclude_subset)
+[void] $gather_subset.UnionWith($explicit_subset)
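+
+# Example of the selection logic above (values are illustrative): with
+# gather_subset=['!all', 'interfaces'], everything except the 'min' facts is
+# excluded first, then 'interfaces' is explicitly added back, so the module
+# gathers the 'min' facts plus 'interfaces'.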
+
+$ansible_facts = @{
+ gather_subset=@($gather_subset_source)
+ module_setup=$true
+}
+
+$osversion = [Environment]::OSVersion
+
+if ($osversion.Version -lt [version]"6.2") {
+ # Server 2008, 2008 R2, and Windows 7 are not tested in CI and we want to let customers know about it before
+ # removing support altogether.
+ $version_string = "{0}.{1}" -f ($osversion.Version.Major, $osversion.Version.Minor)
+ $msg = "Windows version '$version_string' will no longer be supported or tested in the next Ansible release"
+ Add-DeprecationWarning -obj $result -message $msg -version "2.11"
+}
+
+if($gather_subset.Contains('all_ipv4_addresses') -or $gather_subset.Contains('all_ipv6_addresses')) {
+ $netcfg = Get-LazyCimInstance Win32_NetworkAdapterConfiguration
+
+ # TODO: split v4/v6 properly, return in separate keys
+ $ips = @()
+ Foreach ($ip in $netcfg.IPAddress) {
+ If ($ip) {
+ $ips += $ip
+ }
+ }
+
+ $ansible_facts += @{
+ ansible_ip_addresses = $ips
+ }
+}
+
+if($gather_subset.Contains('bios')) {
+ $win32_bios = Get-LazyCimInstance Win32_Bios
+ $win32_cs = Get-LazyCimInstance Win32_ComputerSystem
+ $ansible_facts += @{
+ ansible_bios_date = $win32_bios.ReleaseDate.ToString("MM/dd/yyyy")
+ ansible_bios_version = $win32_bios.SMBIOSBIOSVersion
+ ansible_product_name = $win32_cs.Model.Trim()
+ ansible_product_serial = $win32_bios.SerialNumber
+ # ansible_product_version = ([string] $win32_cs.SystemFamily)
+ }
+}
+
+if($gather_subset.Contains('date_time')) {
+ $datetime = (Get-Date)
+ $datetime_utc = $datetime.ToUniversalTime()
+ $date = @{
+ date = $datetime.ToString("yyyy-MM-dd")
+ day = $datetime.ToString("dd")
+ epoch = (Get-Date -UFormat "%s")
+ hour = $datetime.ToString("HH")
+ iso8601 = $datetime_utc.ToString("yyyy-MM-ddTHH:mm:ssZ")
+ iso8601_basic = $datetime.ToString("yyyyMMddTHHmmssffffff")
+ iso8601_basic_short = $datetime.ToString("yyyyMMddTHHmmss")
+ iso8601_micro = $datetime_utc.ToString("yyyy-MM-ddTHH:mm:ss.ffffffZ")
+ minute = $datetime.ToString("mm")
+ month = $datetime.ToString("MM")
+ second = $datetime.ToString("ss")
+ time = $datetime.ToString("HH:mm:ss")
+ tz = ([System.TimeZoneInfo]::Local.Id)
+ tz_offset = $datetime.ToString("zzzz")
+ # Ensure that the weekday is in English
+ weekday = $datetime.ToString("dddd", [System.Globalization.CultureInfo]::InvariantCulture)
+ weekday_number = (Get-Date -UFormat "%w")
+ weeknumber = (Get-Date -UFormat "%W")
+ year = $datetime.ToString("yyyy")
+ }
+
+ $ansible_facts += @{
+ ansible_date_time = $date
+ }
+}
+
+if($gather_subset.Contains('distribution')) {
+ $win32_os = Get-LazyCimInstance Win32_OperatingSystem
+ $product_type = switch($win32_os.ProductType) {
+ 1 { "workstation" }
+ 2 { "domain_controller" }
+ 3 { "server" }
+ default { "unknown" }
+ }
+
+ $installation_type = $null
+ $current_version_path = "HKLM:\SOFTWARE\Microsoft\Windows NT\CurrentVersion"
+ if (Test-Path -LiteralPath $current_version_path) {
+ $install_type_prop = Get-ItemProperty -LiteralPath $current_version_path -ErrorAction SilentlyContinue
+ $installation_type = [String]$install_type_prop.InstallationType
+ }
+
+ $ansible_facts += @{
+ ansible_distribution = $win32_os.Caption
+ ansible_distribution_version = $osversion.Version.ToString()
+ ansible_distribution_major_version = $osversion.Version.Major.ToString()
+ ansible_os_family = "Windows"
+ ansible_os_name = ($win32_os.Name.Split('|')[0]).Trim()
+ ansible_os_product_type = $product_type
+ ansible_os_installation_type = $installation_type
+ }
+}
+
+if($gather_subset.Contains('env')) {
+ $env_vars = @{ }
+ foreach ($item in Get-ChildItem Env:) {
+ $name = $item | Select-Object -ExpandProperty Name
+ # Powershell ConvertTo-Json fails if string ends with \
+ $value = ($item | Select-Object -ExpandProperty Value).TrimEnd("\")
+ $env_vars.Add($name, $value)
+ }
+
+ $ansible_facts += @{
+ ansible_env = $env_vars
+ }
+}
+
+if($gather_subset.Contains('facter')) {
+ # See if Facter is on the System Path
+ Try {
+ Get-Command facter -ErrorAction Stop > $null
+ $facter_installed = $true
+ } Catch {
+ $facter_installed = $false
+ }
+
+ # Get JSON from Facter, and parse it out.
+ if ($facter_installed) {
+ &facter -j | Tee-Object -Variable facter_output > $null
+ $facts = "$facter_output" | ConvertFrom-Json
+ ForEach($fact in $facts.PSObject.Properties) {
+ $fact_name = $fact.Name
+ $ansible_facts.Add("facter_$fact_name", $fact.Value)
+ }
+ }
+}
+
+if($gather_subset.Contains('interfaces')) {
+ $netcfg = Get-LazyCimInstance Win32_NetworkAdapterConfiguration
+ $ActiveNetcfg = @()
+ $ActiveNetcfg += $netcfg | Where-Object {$_.ipaddress -ne $null}
+
+ $namespaces = Get-LazyCimInstance __Namespace -namespace root
+ if ($namespaces | Where-Object { $_.Name -eq "StandardCimv2" }) {
+ $net_adapters = Get-LazyCimInstance MSFT_NetAdapter -namespace Root\StandardCimv2
+ $guid_key = "InterfaceGUID"
+ $name_key = "Name"
+ } else {
+ $net_adapters = Get-LazyCimInstance Win32_NetworkAdapter
+ $guid_key = "GUID"
+ $name_key = "NetConnectionID"
+ }
+
+ $formattednetcfg = @()
+ foreach ($adapter in $ActiveNetcfg)
+ {
+ $thisadapter = @{
+ default_gateway = $null
+ connection_name = $null
+ dns_domain = $adapter.dnsdomain
+ interface_index = $adapter.InterfaceIndex
+ interface_name = $adapter.description
+ macaddress = $adapter.macaddress
+ }
+
+ if ($adapter.defaultIPGateway)
+ {
+ $thisadapter.default_gateway = $adapter.DefaultIPGateway[0].ToString()
+ }
+ $net_adapter = $net_adapters | Where-Object { $_.$guid_key -eq $adapter.SettingID }
+ if ($net_adapter) {
+ $thisadapter.connection_name = $net_adapter.$name_key
+ }
+
+ $formattednetcfg += $thisadapter
+ }
+
+ $ansible_facts += @{
+ ansible_interfaces = $formattednetcfg
+ }
+}
+
+if ($gather_subset.Contains("local") -and $null -ne $factpath) {
+ # Get any custom facts; results are updated in the $result object
+ Get-CustomFacts -factpath $factpath
+}
+
+if($gather_subset.Contains('memory')) {
+ $win32_cs = Get-LazyCimInstance Win32_ComputerSystem
+ $win32_os = Get-LazyCimInstance Win32_OperatingSystem
+ $ansible_facts += @{
+ # Win32_PhysicalMemory is empty on some virtual platforms
+ ansible_memtotal_mb = ([math]::ceiling($win32_cs.TotalPhysicalMemory / 1024 / 1024))
+ ansible_memfree_mb = ([math]::ceiling($win32_os.FreePhysicalMemory / 1024))
+ ansible_swaptotal_mb = ([math]::round($win32_os.TotalSwapSpaceSize / 1024))
+ ansible_pagefiletotal_mb = ([math]::round($win32_os.SizeStoredInPagingFiles / 1024))
+ ansible_pagefilefree_mb = ([math]::round($win32_os.FreeSpaceInPagingFiles / 1024))
+ }
+}
+
+
+if($gather_subset.Contains('platform')) {
+ $win32_cs = Get-LazyCimInstance Win32_ComputerSystem
+ $win32_os = Get-LazyCimInstance Win32_OperatingSystem
+ $domain_suffix = $win32_cs.Domain.Substring($win32_cs.Workgroup.length)
+ $fqdn = $win32_cs.DNSHostname
+
+ if( $domain_suffix -ne "")
+ {
+ $fqdn = $win32_cs.DNSHostname + "." + $domain_suffix
+ }
+
+ try {
+ $ansible_reboot_pending = Get-PendingRebootStatus
+ } catch {
+ # fails for non-admin users, set to null in this case
+ $ansible_reboot_pending = $null
+ }
+
+ $ansible_facts += @{
+ ansible_architecture = $win32_os.OSArchitecture
+ ansible_domain = $domain_suffix
+ ansible_fqdn = $fqdn
+ ansible_hostname = $win32_cs.DNSHostname
+ ansible_netbios_name = $win32_cs.Name
+ ansible_kernel = $osversion.Version.ToString()
+ ansible_nodename = $fqdn
+ ansible_machine_id = Get-MachineSid
+ ansible_owner_contact = ([string] $win32_cs.PrimaryOwnerContact)
+ ansible_owner_name = ([string] $win32_cs.PrimaryOwnerName)
+ # FUTURE: should this live in its own subset?
+ ansible_reboot_pending = $ansible_reboot_pending
+ ansible_system = $osversion.Platform.ToString()
+ ansible_system_description = ([string] $win32_os.Description)
+ ansible_system_vendor = $win32_cs.Manufacturer
+ }
+}
+
+if($gather_subset.Contains('powershell_version')) {
+ $ansible_facts += @{
+ ansible_powershell_version = ($PSVersionTable.PSVersion.Major)
+ }
+}
+
+if($gather_subset.Contains('processor')) {
+ $win32_cs = Get-LazyCimInstance Win32_ComputerSystem
+ $win32_cpu = Get-LazyCimInstance Win32_Processor
+ if ($win32_cpu -is [array]) {
+ # multi-socket, pick first
+ $win32_cpu = $win32_cpu[0]
+ }
+
+ $cpu_list = @( )
+ for ($i=1; $i -le $win32_cs.NumberOfLogicalProcessors; $i++) {
+ $cpu_list += $win32_cpu.Manufacturer
+ $cpu_list += $win32_cpu.Name
+ }
+
+ $ansible_facts += @{
+ ansible_processor = $cpu_list
+ ansible_processor_cores = $win32_cpu.NumberOfCores
+ ansible_processor_count = $win32_cs.NumberOfProcessors
+ ansible_processor_threads_per_core = ($win32_cpu.NumberOfLogicalProcessors / $win32_cpu.NumberofCores)
+ ansible_processor_vcpus = $win32_cs.NumberOfLogicalProcessors
+ }
+}
+
+if($gather_subset.Contains('uptime')) {
+ $win32_os = Get-LazyCimInstance Win32_OperatingSystem
+ $ansible_facts += @{
+ ansible_lastboot = $win32_os.lastbootuptime.ToString("u")
+ ansible_uptime_seconds = $([System.Convert]::ToInt64($(Get-Date).Subtract($win32_os.lastbootuptime).TotalSeconds))
+ }
+}
+
+if($gather_subset.Contains('user')) {
+ $user = [Security.Principal.WindowsIdentity]::GetCurrent()
+ $ansible_facts += @{
+ ansible_user_dir = $env:userprofile
+ # Win32_UserAccount.FullName is probably the right thing here, but it can be expensive to get on large domains
+ ansible_user_gecos = ""
+ ansible_user_id = $env:username
+ ansible_user_sid = $user.User.Value
+ }
+}
+
+if($gather_subset.Contains('windows_domain')) {
+ $win32_cs = Get-LazyCimInstance Win32_ComputerSystem
+ $domain_roles = @{
+ 0 = "Stand-alone workstation"
+ 1 = "Member workstation"
+ 2 = "Stand-alone server"
+ 3 = "Member server"
+ 4 = "Backup domain controller"
+ 5 = "Primary domain controller"
+ }
+
+ $domain_role = $domain_roles.Get_Item([Int32]$win32_cs.DomainRole)
+
+ $ansible_facts += @{
+ ansible_windows_domain = $win32_cs.Domain
+ ansible_windows_domain_member = $win32_cs.PartOfDomain
+ ansible_windows_domain_role = $domain_role
+ }
+}
+
+if($gather_subset.Contains('winrm')) {
+
+ $winrm_https_listener_parent_paths = Get-ChildItem -Path WSMan:\localhost\Listener -Recurse -ErrorAction SilentlyContinue | `
+ Where-Object {$_.PSChildName -eq "Transport" -and $_.Value -eq "HTTPS"} | Select-Object PSParentPath
+ if ($winrm_https_listener_parent_paths -isnot [array]) {
+ $winrm_https_listener_parent_paths = @($winrm_https_listener_parent_paths)
+ }
+
+ $winrm_https_listener_paths = @()
+ foreach ($winrm_https_listener_parent_path in $winrm_https_listener_parent_paths) {
+ $winrm_https_listener_paths += $winrm_https_listener_parent_path.PSParentPath.Substring($winrm_https_listener_parent_path.PSParentPath.LastIndexOf("\"))
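+ # e.g. (illustrative) a PSParentPath ending in "\Listener\Listener_1084132640"
+ # yields the suffix "\Listener_1084132640"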
+ }
+
+ $https_listeners = @()
+ foreach ($winrm_https_listener_path in $winrm_https_listener_paths) {
+ $https_listeners += Get-ChildItem -Path "WSMan:\localhost\Listener$winrm_https_listener_path"
+ }
+
+ $winrm_cert_thumbprints = @()
+ foreach ($https_listener in $https_listeners) {
+ $winrm_cert_thumbprints += $https_listener | Where-Object {$_.Name -EQ "CertificateThumbprint" } | Select-Object Value
+ }
+
+ $winrm_cert_expiry = @()
+ foreach ($winrm_cert_thumbprint in $winrm_cert_thumbprints) {
+ Try {
+ $winrm_cert_expiry += Get-ChildItem -Path Cert:\LocalMachine\My | Where-Object Thumbprint -EQ $winrm_cert_thumbprint.Value.ToString().ToUpper() | Select-Object NotAfter
+ } Catch {
+ Add-Warning -obj $result -message "Error during certificate expiration retrieval: $($_.Exception.Message)"
+ }
+ }
+
+ $winrm_cert_expirations = $winrm_cert_expiry | Sort-Object NotAfter
+ if ($winrm_cert_expirations) {
+ # this fact was renamed from ansible_winrm_certificate_expires due to collision with ansible_winrm_X connection var pattern
+ $ansible_facts.Add("ansible_win_rm_certificate_expires", $winrm_cert_expirations[0].NotAfter.ToString("yyyy-MM-dd HH:mm:ss"))
+ }
+}
+
+if($gather_subset.Contains('virtual')) {
+ $machine_info = Get-LazyCimInstance Win32_ComputerSystem
+
+ switch ($machine_info.model) {
+ "Virtual Machine" {
+ $machine_type="Hyper-V"
+ $machine_role="guest"
+ }
+
+ "VMware Virtual Platform" {
+ $machine_type="VMware"
+ $machine_role="guest"
+ }
+
+ "VirtualBox" {
+ $machine_type="VirtualBox"
+ $machine_role="guest"
+ }
+
+ "HVM domU" {
+ $machine_type="Xen"
+ $machine_role="guest"
+ }
+
+ default {
+ $machine_type="NA"
+ $machine_role="NA"
+ }
+ }
+
+ $ansible_facts += @{
+ ansible_virtualization_role = $machine_role
+ ansible_virtualization_type = $machine_type
+ }
+}
+
+$result.ansible_facts += $ansible_facts
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/slurp.ps1 b/test/support/windows-integration/plugins/modules/slurp.ps1
new file mode 100644
index 00000000..eb506c7c
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/slurp.ps1
@@ -0,0 +1,28 @@
+#!powershell
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+
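+# Typical playbook usage (illustrative only; b64decode is a standard Ansible filter):
+#   - name: Read a remote file
+#     slurp:
+#       src: C:\Temp\foo.txt
+#     register: foo
+#   - debug:
+#       msg: "{{ foo.content | b64decode }}"
+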
+$params = Parse-Args $args -supports_check_mode $true;
+$src = Get-AnsibleParam -obj $params -name "src" -type "path" -aliases "path" -failifempty $true;
+
+$result = @{
+ changed = $false;
+}
+
+If (Test-Path -LiteralPath $src -PathType Leaf)
+{
+ $bytes = [System.IO.File]::ReadAllBytes($src);
+ $result.content = [System.Convert]::ToBase64String($bytes);
+ $result.encoding = "base64";
+ Exit-Json $result;
+}
+ElseIf (Test-Path -LiteralPath $src -PathType Container)
+{
+ Fail-Json $result "Path $src is a directory";
+}
+Else
+{
+ Fail-Json $result "Path $src is not found";
+}
diff --git a/test/support/windows-integration/plugins/modules/win_acl.ps1 b/test/support/windows-integration/plugins/modules/win_acl.ps1
new file mode 100644
index 00000000..e3c38130
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_acl.ps1
@@ -0,0 +1,225 @@
+#!powershell
+
+# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com>
+# Copyright: (c) 2015, Trond Hindenes
+# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.PrivilegeUtil
+#Requires -Module Ansible.ModuleUtils.SID
+
+$ErrorActionPreference = "Stop"
+
+# win_acl module (File/Resources Permission Additions/Removal)
+
+#Functions
+function Get-UserSID {
+ param(
+ [String]$AccountName
+ )
+
+ $userSID = $null
+ $searchAppPools = $false
+
+ if ($AccountName.Split("\").Count -gt 1) {
+ if ($AccountName.Split("\")[0] -eq "IIS APPPOOL") {
+ $searchAppPools = $true
+ $AccountName = $AccountName.Split("\")[1]
+ }
+ }
+
+ if ($searchAppPools) {
+ Import-Module -Name WebAdministration
+ $testIISPath = Test-Path -LiteralPath "IIS:"
+ if ($testIISPath) {
+ $appPoolObj = Get-ItemProperty -LiteralPath "IIS:\AppPools\$AccountName"
+ $userSID = $appPoolObj.applicationPoolSid
+ }
+ }
+ else {
+ $userSID = Convert-ToSID -account_name $AccountName
+ }
+
+ return $userSID
+}
+
+$params = Parse-Args $args
+
+Function SetPrivilegeTokens() {
+ # Set privilege tokens only if admin.
+ # Admins would have these privs or be able to set these privs in the UI anyway
+
+ $adminRole=[System.Security.Principal.WindowsBuiltInRole]::Administrator
+ $myWindowsID=[System.Security.Principal.WindowsIdentity]::GetCurrent()
+ $myWindowsPrincipal=new-object System.Security.Principal.WindowsPrincipal($myWindowsID)
+
+
+ if ($myWindowsPrincipal.IsInRole($adminRole)) {
+ # Need to adjust token privs when executing Set-ACL in certain cases.
+ # e.g. d:\testdir is owned by a group in which the current user is not a member and no perms are inherited from d:\
+ # This also sets us up for setting the owner as a feature.
+ # See the following for details of each privilege
+ # https://msdn.microsoft.com/en-us/library/windows/desktop/bb530716(v=vs.85).aspx
+ $privileges = @(
+ "SeRestorePrivilege", # Grants all write access control to any file, regardless of ACL.
+ "SeBackupPrivilege", # Grants all read access control to any file, regardless of ACL.
+ "SeTakeOwnershipPrivilege" # Grants ability to take owernship of an object w/out being granted discretionary access
+ )
+ foreach ($privilege in $privileges) {
+ $state = Get-AnsiblePrivilege -Name $privilege
+ if ($state -eq $false) {
+ Set-AnsiblePrivilege -Name $privilege -Value $true
+ }
+ }
+ }
+}
+
+
+$result = @{
+ changed = $false
+}
+
+$path = Get-AnsibleParam -obj $params -name "path" -type "str" -failifempty $true
+$user = Get-AnsibleParam -obj $params -name "user" -type "str" -failifempty $true
+$rights = Get-AnsibleParam -obj $params -name "rights" -type "str" -failifempty $true
+
+$type = Get-AnsibleParam -obj $params -name "type" -type "str" -failifempty $true -validateset "allow","deny"
+$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "absent","present"
+
+$inherit = Get-AnsibleParam -obj $params -name "inherit" -type "str"
+$propagation = Get-AnsibleParam -obj $params -name "propagation" -type "str" -default "None" -validateset "InheritOnly","None","NoPropagateInherit"
+
+# We mount the HKCR, HKU, and HKCC registry hives so PS can access them.
+# Network paths have no qualifiers so we use -EA SilentlyContinue to ignore that
+$path_qualifier = Split-Path -Path $path -Qualifier -ErrorAction SilentlyContinue
+if ($path_qualifier -eq "HKCR:" -and (-not (Test-Path -LiteralPath HKCR:\))) {
+ New-PSDrive -Name HKCR -PSProvider Registry -Root HKEY_CLASSES_ROOT > $null
+}
+if ($path_qualifier -eq "HKU:" -and (-not (Test-Path -LiteralPath HKU:\))) {
+ New-PSDrive -Name HKU -PSProvider Registry -Root HKEY_USERS > $null
+}
+if ($path_qualifier -eq "HKCC:" -and (-not (Test-Path -LiteralPath HKCC:\))) {
+ New-PSDrive -Name HKCC -PSProvider Registry -Root HKEY_CURRENT_CONFIG > $null
+}
+
+If (-Not (Test-Path -LiteralPath $path)) {
+ Fail-Json -obj $result -message "$path file or directory does not exist on the host"
+}
+
+# Test that the user/group is resolvable on the local machine
+$sid = Get-UserSID -AccountName $user
+if (!$sid) {
+ Fail-Json -obj $result -message "$user is not a valid user or group on the host machine or domain"
+}
+
+If (Test-Path -LiteralPath $path -PathType Leaf) {
+ $inherit = "None"
+}
+ElseIf ($null -eq $inherit) {
+ $inherit = "ContainerInherit, ObjectInherit"
+}
+
+# Bug in Set-Acl, Get-Acl where -LiteralPath only works for the Registry provider if the location is in that root
+# qualifier. We also don't have a qualifier for a network path so only change if not null
+if ($null -ne $path_qualifier) {
+ Push-Location -LiteralPath $path_qualifier
+}
+
+Try {
+ SetPrivilegeTokens
+ $path_item = Get-Item -LiteralPath $path -Force
+ If ($path_item.PSProvider.Name -eq "Registry") {
+ $colRights = [System.Security.AccessControl.RegistryRights]$rights
+ }
+ Else {
+ $colRights = [System.Security.AccessControl.FileSystemRights]$rights
+ }
+
+ $InheritanceFlag = [System.Security.AccessControl.InheritanceFlags]$inherit
+ $PropagationFlag = [System.Security.AccessControl.PropagationFlags]$propagation
+
+ If ($type -eq "allow") {
+ $objType =[System.Security.AccessControl.AccessControlType]::Allow
+ }
+ Else {
+ $objType =[System.Security.AccessControl.AccessControlType]::Deny
+ }
+
+ $objUser = New-Object System.Security.Principal.SecurityIdentifier($sid)
+ If ($path_item.PSProvider.Name -eq "Registry") {
+ $objACE = New-Object System.Security.AccessControl.RegistryAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType)
+ }
+ Else {
+ $objACE = New-Object System.Security.AccessControl.FileSystemAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType)
+ }
+ $objACL = Get-ACL -LiteralPath $path
+
+ # Check if the ACE already exists in the object's ACL list
+ $match = $false
+
+ ForEach($rule in $objACL.GetAccessRules($true, $true, [System.Security.Principal.SecurityIdentifier])){
+
+ If ($path_item.PSProvider.Name -eq "Registry") {
+ If (($rule.RegistryRights -eq $objACE.RegistryRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) {
+ $match = $true
+ Break
+ }
+ } else {
+ If (($rule.FileSystemRights -eq $objACE.FileSystemRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) {
+ $match = $true
+ Break
+ }
+ }
+ }
+
+ If ($state -eq "present" -And $match -eq $false) {
+ Try {
+ $objACL.AddAccessRule($objACE)
+ If ($path_item.PSProvider.Name -eq "Registry") {
+ Set-ACL -LiteralPath $path -AclObject $objACL
+ } else {
+ (Get-Item -LiteralPath $path).SetAccessControl($objACL)
+ }
+ $result.changed = $true
+ }
+ Catch {
+ Fail-Json -obj $result -message "an exception occurred when adding the specified rule - $($_.Exception.Message)"
+ }
+ }
+ ElseIf ($state -eq "absent" -And $match -eq $true) {
+ Try {
+ $objACL.RemoveAccessRule($objACE)
+ If ($path_item.PSProvider.Name -eq "Registry") {
+ Set-ACL -LiteralPath $path -AclObject $objACL
+ } else {
+ (Get-Item -LiteralPath $path).SetAccessControl($objACL)
+ }
+ $result.changed = $true
+ }
+ Catch {
+ Fail-Json -obj $result -message "an exception occurred when removing the specified rule - $($_.Exception.Message)"
+ }
+ }
+ Else {
+ # The rule to be added already exists
+ If ($match -eq $true) {
+ Exit-Json -obj $result -message "the specified rule already exists"
+ }
+ # The rule to be removed does not exist
+ Else {
+ Exit-Json -obj $result -message "the specified rule does not exist"
+ }
+ }
+}
+Catch {
+ Fail-Json -obj $result -message "an error occurred when attempting to $state $rights permission(s) on $path for $user - $($_.Exception.Message)"
+}
+Finally {
+ # Make sure we revert the location stack to the original path just for cleanups sake
+ if ($null -ne $path_qualifier) {
+ Pop-Location
+ }
+}
+
+Exit-Json -obj $result
diff --git a/test/support/windows-integration/plugins/modules/win_acl.py b/test/support/windows-integration/plugins/modules/win_acl.py
new file mode 100644
index 00000000..14fbd82f
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_acl.py
@@ -0,0 +1,132 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com>
+# Copyright: (c) 2015, Trond Hindenes
+# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_acl
+version_added: "2.0"
+short_description: Set file/directory/registry permissions for a system user or group
+description:
+- Add or remove rights/permissions for a given user or group for the specified
+ file, folder, registry key or AppPool identities.
+options:
+ path:
+ description:
+ - The path to the file or directory.
+ type: str
+ required: yes
+ user:
+ description:
+ - User or Group to add specified rights to act on src file/folder or
+ registry key.
+ type: str
+ required: yes
+ state:
+ description:
+ - Specify whether to add C(present) or remove C(absent) the specified access rule.
+ type: str
+ choices: [ absent, present ]
+ default: present
+ type:
+ description:
+ - Specify whether to allow or deny the rights specified.
+ type: str
+ required: yes
+ choices: [ allow, deny ]
+ rights:
+ description:
+ - The rights/permissions that are to be allowed/denied for the specified
+ user or group for the item at C(path).
+ - If C(path) is a file or directory, rights can be any right under MSDN
+ FileSystemRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.filesystemrights.aspx).
+ - If C(path) is a registry key, rights can be any right under MSDN
+ RegistryRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.registryrights.aspx).
+ type: str
+ required: yes
+ inherit:
+ description:
+ - Inherit flags on the ACL rules.
+ - Can be specified as a comma separated list, e.g. C(ContainerInherit),
+ C(ObjectInherit).
+ - For more information on the choices see MSDN InheritanceFlags enumeration
+ at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.inheritanceflags.aspx).
+ - Defaults to C(ContainerInherit, ObjectInherit) for Directories.
+ type: str
+ choices: [ ContainerInherit, ObjectInherit ]
+ propagation:
+ description:
+ - Propagation flag on the ACL rules.
+ - For more information on the choices see MSDN PropagationFlags enumeration
+ at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.propagationflags.aspx).
+ type: str
+ choices: [ InheritOnly, None, NoPropagateInherit ]
+ default: "None"
+notes:
+- If adding ACLs for AppPool identities (available since 2.3), the Windows
+ Feature "Web-Scripting-Tools" must be enabled.
+seealso:
+- module: win_acl_inheritance
+- module: win_file
+- module: win_owner
+- module: win_stat
+author:
+- Phil Schwartz (@schwartzmx)
+- Trond Hindenes (@trondhindenes)
+- Hans-Joachim Kliemeck (@h0nIg)
+'''
+
+EXAMPLES = r'''
+- name: Restrict write and execute access to User Fed-Phil
+ win_acl:
+ user: Fed-Phil
+ path: C:\Important\Executable.exe
+ type: deny
+ rights: ExecuteFile,Write
+
+- name: Add IIS_IUSRS allow rights
+ win_acl:
+ path: C:\inetpub\wwwroot\MySite
+ user: IIS_IUSRS
+ rights: FullControl
+ type: allow
+ state: present
+ inherit: ContainerInherit, ObjectInherit
+ propagation: 'None'
+
+- name: Set registry key right
+ win_acl:
+ path: HKCU:\Bovine\Key
+ user: BUILTIN\Users
+ rights: EnumerateSubKeys
+ type: allow
+ state: present
+ inherit: ContainerInherit, ObjectInherit
+ propagation: 'None'
+
+- name: Remove FullControl AccessRule for IIS_IUSRS
+ win_acl:
+ path: C:\inetpub\wwwroot\MySite
+ user: IIS_IUSRS
+ rights: FullControl
+ type: allow
+ state: absent
+ inherit: ContainerInherit, ObjectInherit
+ propagation: 'None'
+
+- name: Deny Intern
+ win_acl:
+ path: C:\Administrator\Documents
+ user: Intern
+ rights: Read,Write,Modify,FullControl,Delete
+ type: deny
+ state: present
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_certificate_store.ps1 b/test/support/windows-integration/plugins/modules/win_certificate_store.ps1
new file mode 100644
index 00000000..db984130
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_certificate_store.ps1
@@ -0,0 +1,260 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+
+$store_name_values = ([System.Security.Cryptography.X509Certificates.StoreName]).GetEnumValues() | ForEach-Object { $_.ToString() }
+$store_location_values = ([System.Security.Cryptography.X509Certificates.StoreLocation]).GetEnumValues() | ForEach-Object { $_.ToString() }
+
+$spec = @{
+ options = @{
+ state = @{ type = "str"; default = "present"; choices = "absent", "exported", "present" }
+ path = @{ type = "path" }
+ thumbprint = @{ type = "str" }
+ store_name = @{ type = "str"; default = "My"; choices = $store_name_values }
+ store_location = @{ type = "str"; default = "LocalMachine"; choices = $store_location_values }
+ password = @{ type = "str"; no_log = $true }
+ key_exportable = @{ type = "bool"; default = $true }
+ key_storage = @{ type = "str"; default = "default"; choices = "default", "machine", "user" }
+ file_type = @{ type = "str"; default = "der"; choices = "der", "pem", "pkcs12" }
+ }
+ required_if = @(
+ @("state", "absent", @("path", "thumbprint"), $true),
+ @("state", "exported", @("path", "thumbprint")),
+ @("state", "present", @("path"))
+ )
+ supports_check_mode = $true
+}
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+
+Function Get-CertFile($module, $path, $password, $key_exportable, $key_storage) {
+ # parses a certificate file and returns X509Certificate2Collection
+ if (-not (Test-Path -LiteralPath $path -PathType Leaf)) {
+ $module.FailJson("File at '$path' either does not exist or is not a file")
+ }
+
+ # must set at least the PersistKeySet flag so that the PrivateKey
+ # is stored in a permanent container and not deleted once the handle
+ # is gone.
+ $store_flags = [System.Security.Cryptography.X509Certificates.X509KeyStorageFlags]::PersistKeySet
+
+ $key_storage = $key_storage.substring(0,1).ToUpper() + $key_storage.substring(1).ToLower()
+ $store_flags = $store_flags -bor [Enum]::Parse([System.Security.Cryptography.X509Certificates.X509KeyStorageFlags], "$($key_storage)KeySet")
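+ # e.g. key_storage "machine" maps to X509KeyStorageFlags.MachineKeySet, "user" to UserKeySet, "default" to DefaultKeySet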
+ if ($key_exportable) {
+ $store_flags = $store_flags -bor [System.Security.Cryptography.X509Certificates.X509KeyStorageFlags]::Exportable
+ }
+
+ # TODO: If I'm feeling adventurous, write code to parse PKCS#12 PEM encoded
+ # file as .NET does not have an easy way to import this
+ $certs = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Certificate2Collection
+
+ try {
+ $certs.Import($path, $password, $store_flags)
+ } catch {
+ $module.FailJson("Failed to load cert from file: $($_.Exception.Message)", $_)
+ }
+
+ return $certs
+}
+
+Function New-CertFile($module, $cert, $path, $type, $password) {
+ $content_type = switch ($type) {
+ "pem" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Cert }
+ "der" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Cert }
+ "pkcs12" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Pkcs12 }
+ }
+ if ($type -eq "pkcs12") {
+ $missing_key = $false
+ if ($null -eq $cert.PrivateKey) {
+ $missing_key = $true
+ } elseif ($cert.PrivateKey.CspKeyContainerInfo.Exportable -eq $false) {
+ $missing_key = $true
+ }
+ if ($missing_key) {
+ $module.FailJson("Cannot export cert with key as PKCS12 when the key is not marked as exportable or not accessible by the current user")
+ }
+ }
+
+ if (Test-Path -LiteralPath $path) {
+ Remove-Item -LiteralPath $path -Force
+ $module.Result.changed = $true
+ }
+ try {
+ $cert_bytes = $cert.Export($content_type, $password)
+ } catch {
+ $module.FailJson("Failed to export certificate as bytes: $($_.Exception.Message)", $_)
+ }
+
+ # Need to manually handle a PEM file
+ if ($type -eq "pem") {
+ $cert_content = "-----BEGIN CERTIFICATE-----`r`n"
+ $base64_string = [System.Convert]::ToBase64String($cert_bytes, [System.Base64FormattingOptions]::InsertLineBreaks)
+ $cert_content += $base64_string
+ $cert_content += "`r`n-----END CERTIFICATE-----"
+ $file_encoding = [System.Text.Encoding]::ASCII
+ $cert_bytes = $file_encoding.GetBytes($cert_content)
+ } elseif ($type -eq "pkcs12") {
+ $module.Result.key_exported = $false
+ if ($null -ne $cert.PrivateKey) {
+ $module.Result.key_exportable = $cert.PrivateKey.CspKeyContainerInfo.Exportable
+ }
+ }
+
+ if (-not $module.CheckMode) {
+ try {
+ [System.IO.File]::WriteAllBytes($path, $cert_bytes)
+ } catch [System.ArgumentNullException] {
+ $module.FailJson("Failed to write cert to file, cert was null: $($_.Exception.Message)", $_)
+ } catch [System.IO.IOException] {
+ $module.FailJson("Failed to write cert to file due to IO Exception: $($_.Exception.Message)", $_)
+ } catch [System.UnauthorizedAccessException] {
+ $module.FailJson("Failed to write cert to file due to permissions: $($_.Exception.Message)", $_)
+ } catch {
+ $module.FailJson("Failed to write cert to file: $($_.Exception.Message)", $_)
+ }
+ }
+ $module.Result.changed = $true
+}
+
+Function Get-CertFileType($path, $password) {
+ $certs = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Certificate2Collection
+ try {
+ $certs.Import($path, $password, 0)
+ } catch [System.Security.Cryptography.CryptographicException] {
+ # the file is a pkcs12, we just had the wrong password
+ return "pkcs12"
+ } catch {
+ return "unknown"
+ }
+
+ $file_contents = Get-Content -LiteralPath $path -Raw
+ if ($file_contents.StartsWith("-----BEGIN CERTIFICATE-----")) {
+ return "pem"
+ } elseif ($file_contents.StartsWith("-----BEGIN PKCS7-----")) {
+ return "pkcs7-ascii"
+ } elseif ($certs.Count -gt 1) {
+ # multiple certs must be pkcs7
+ return "pkcs7-binary"
+ } elseif ($certs[0].HasPrivateKey) {
+ return "pkcs12"
+ } elseif ($path.EndsWith(".pfx") -or $path.EndsWith(".p12")) {
+ # no way to differentiate a pfx from a der file so we must rely on the
+ # extension
+ return "pkcs12"
+ } else {
+ return "der"
+ }
+}
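+
+# Illustrative behaviour of the helper above (assumed example paths):
+#   Get-CertFileType -path C:\temp\cert.pem -password $null    # -> "pem"
+#   Get-CertFileType -path C:\temp\cert.pfx -password "secret" # -> "pkcs12"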
+
+$state = $module.Params.state
+$path = $module.Params.path
+$thumbprint = $module.Params.thumbprint
+$store_name = [System.Security.Cryptography.X509Certificates.StoreName]"$($module.Params.store_name)"
+$store_location = [System.Security.Cryptography.X509Certificates.Storelocation]"$($module.Params.store_location)"
+$password = $module.Params.password
+$key_exportable = $module.Params.key_exportable
+$key_storage = $module.Params.key_storage
+$file_type = $module.Params.file_type
+
+$module.Result.thumbprints = @()
+
+$store = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Store -ArgumentList $store_name, $store_location
+try {
+ $store.Open([System.Security.Cryptography.X509Certificates.OpenFlags]::ReadWrite)
+} catch [System.Security.Cryptography.CryptographicException] {
+ $module.FailJson("Unable to open the store as it is not readable: $($_.Exception.Message)", $_)
+} catch [System.Security.SecurityException] {
+ $module.FailJson("Unable to open the store with the current permissions: $($_.Exception.Message)", $_)
+} catch {
+ $module.FailJson("Unable to open the store: $($_.Exception.Message)", $_)
+}
+$store_certificates = $store.Certificates
+
+try {
+ if ($state -eq "absent") {
+ $cert_thumbprints = @()
+
+ if ($null -ne $path) {
+ $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage
+ foreach ($cert in $certs) {
+ $cert_thumbprints += $cert.Thumbprint
+ }
+ } elseif ($null -ne $thumbprint) {
+ $cert_thumbprints += $thumbprint
+ }
+
+ foreach ($cert_thumbprint in $cert_thumbprints) {
+ $module.Result.thumbprints += $cert_thumbprint
+ $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $cert_thumbprint, $false)
+ if ($found_certs.Count -gt 0) {
+ foreach ($found_cert in $found_certs) {
+ try {
+ if (-not $module.CheckMode) {
+ $store.Remove($found_cert)
+ }
+ } catch [System.Security.SecurityException] {
+ $module.FailJson("Unable to remove cert with thumbprint '$cert_thumbprint' with current permissions: $($_.Exception.Message)", $_)
+ } catch {
+ $module.FailJson("Unable to remove cert with thumbprint '$cert_thumbprint': $($_.Exception.Message)", $_)
+ }
+ $module.Result.changed = $true
+ }
+ }
+ }
+ } elseif ($state -eq "exported") {
+ # TODO: Add support for PKCS7 and exporting a cert chain
+ $module.Result.thumbprints += $thumbprint
+ $export = $true
+ if (Test-Path -LiteralPath $path -PathType Container) {
+ $module.FailJson("Cannot export cert to path '$path' as it is a directory")
+ } elseif (Test-Path -LiteralPath $path -PathType Leaf) {
+ $actual_cert_type = Get-CertFileType -path $path -password $password
+ if ($actual_cert_type -eq $file_type) {
+ try {
+ $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage
+ } catch {
+ # failed to load the file so we set the thumbprint to something
+ # that will fail validation
+ $certs = @{Thumbprint = $null}
+ }
+
+ if ($certs.Thumbprint -eq $thumbprint) {
+ $export = $false
+ }
+ }
+ }
+
+ if ($export) {
+ $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $thumbprint, $false)
+ if ($found_certs.Count -ne 1) {
+ $module.FailJson("Found $($found_certs.Count) certs when only expecting 1")
+ }
+
+ New-CertFile -module $module -cert $found_certs -path $path -type $file_type -password $password
+ }
+ } else {
+ $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage
+ foreach ($cert in $certs) {
+ $module.Result.thumbprints += $cert.Thumbprint
+ $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $cert.Thumbprint, $false)
+ if ($found_certs.Count -eq 0) {
+ try {
+ if (-not $module.CheckMode) {
+ $store.Add($cert)
+ }
+ } catch [System.Security.Cryptography.CryptographicException] {
+ $module.FailJson("Unable to import certificate with thumbprint '$($cert.Thumbprint)' with the current permissions: $($_.Exception.Message)", $_)
+ } catch {
+ $module.FailJson("Unable to import certificate with thumbprint '$($cert.Thumbprint)': $($_.Exception.Message)", $_)
+ }
+ $module.Result.changed = $true
+ }
+ }
+ }
+} finally {
+ $store.Close()
+}
+
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_certificate_store.py b/test/support/windows-integration/plugins/modules/win_certificate_store.py
new file mode 100644
index 00000000..dc617e33
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_certificate_store.py
@@ -0,0 +1,208 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_certificate_store
+version_added: '2.5'
+short_description: Manages the certificate store
+description:
+- Used to import/export and remove certificates and keys from the local
+ certificate store.
+- This module is not used to create certificates and will only manage existing
+ certs as a file or in the store.
+- It can be used to import PEM, DER, P7B, PKCS12 (PFX) certificates and export
+ PEM, DER and PKCS12 certificates.
+options:
+ state:
+ description:
+ - If C(present), will ensure that the certificate at I(path) is imported
+ into the certificate store specified.
+ - If C(absent), will ensure that the certificate specified by I(thumbprint)
+ or the thumbprint of the cert at I(path) is removed from the store
+ specified.
+ - If C(exported), will ensure the file at I(path) is a certificate
+ specified by I(thumbprint).
+ - When exporting a certificate, if I(path) is a directory then the module
+ will fail, otherwise the file will be replaced if needed.
+ type: str
+ choices: [ absent, exported, present ]
+ default: present
+ path:
+ description:
+ - The path to a certificate file.
+ - This is required when I(state) is C(present) or C(exported).
+ - When I(state) is C(absent) and I(thumbprint) is not specified, the
+ thumbprint is derived from the certificate at this path.
+ type: path
+ thumbprint:
+ description:
+ - The thumbprint as a hex string to either export or remove.
+ - See the examples for how to specify the thumbprint.
+ type: str
+ store_name:
+ description:
+ - The store name to use when importing a certificate or searching for a
+ certificate.
+ - "C(AddressBook): The X.509 certificate store for other users"
+ - "C(AuthRoot): The X.509 certificate store for third-party certificate authorities (CAs)"
+ - "C(CertificateAuthority): The X.509 certificate store for intermediate certificate authorities (CAs)"
+ - "C(Disallowed): The X.509 certificate store for revoked certificates"
+ - "C(My): The X.509 certificate store for personal certificates"
+ - "C(Root): The X.509 certificate store for trusted root certificate authorities (CAs)"
+ - "C(TrustedPeople): The X.509 certificate store for directly trusted people and resources"
+ - "C(TrustedPublisher): The X.509 certificate store for directly trusted publishers"
+ type: str
+ choices:
+ - AddressBook
+ - AuthRoot
+ - CertificateAuthority
+ - Disallowed
+ - My
+ - Root
+ - TrustedPeople
+ - TrustedPublisher
+ default: My
+ store_location:
+ description:
+ - The store location to use when importing a certificate or searching for a
+ certificate.
+ choices: [ CurrentUser, LocalMachine ]
+ default: LocalMachine
+ password:
+ description:
+ - The password of the pkcs12 certificate key.
+ - This is used when reading a pkcs12 certificate file or the password to
+ set when C(state=exported) and C(file_type=pkcs12).
+ - If the pkcs12 file has no password set or no password should be set on
+ the exported file, do not set this option.
+ type: str
+ key_exportable:
+ description:
+ - Whether to allow the private key to be exported.
+ - If C(no), then this module and other processes will only be able to export
+ the certificate and the private key cannot be exported.
+ - Used when C(state=present) only.
+ type: bool
+ default: yes
+ key_storage:
+ description:
+ - Specifies where Windows will store the private key when it is imported.
+ - When set to C(default), the default option as set by Windows is used, typically C(user).
+ - When set to C(machine), the key is stored in a path accessible by various
+ users.
+ - When set to C(user), the key is stored in a path only accessible by the
+ current user.
+ - Used when C(state=present) only and cannot be changed once imported.
+ - See U(https://msdn.microsoft.com/en-us/library/system.security.cryptography.x509certificates.x509keystorageflags.aspx)
+ for more details.
+ type: str
+ choices: [ default, machine, user ]
+ default: default
+ file_type:
+ description:
+ - The file type to export the certificate as when C(state=exported).
+ - C(der) is a binary ASN.1 encoded file.
+ - C(pem) is a base64 encoded file of a der file in the OpenSSL form.
+ - C(pkcs12) (also known as pfx) is a binary container that contains both
+ the certificate and private key unlike the other options.
+ - When C(pkcs12) is set and the private key is not exportable or accessible
+ by the current user, it will throw an exception.
+ type: str
+ choices: [ der, pem, pkcs12 ]
+ default: der
+notes:
+- Some actions on PKCS12 certificates and keys may fail with the error
+ C(the specified network password is not correct), either use CredSSP or
+ Kerberos with credential delegation, or use C(become) to bypass these
+ restrictions.
+- The certificates must be located on the Windows host to be set with I(path).
+- When importing a certificate for usage in IIS, it is generally required
+ to use the C(machine) key_storage option, as both C(default) and C(user)
+ will make the private key unreadable to IIS APPPOOL identities and prevent
+ binding the certificate to the https endpoint.
+author:
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+- name: Import a certificate
+ win_certificate_store:
+ path: C:\Temp\cert.pem
+ state: present
+
+- name: Import pfx certificate that is password protected
+ win_certificate_store:
+ path: C:\Temp\cert.pfx
+ state: present
+ password: VeryStrongPasswordHere!
+ become: yes
+ become_method: runas
+
+- name: Import pfx certificate without password and set private key as un-exportable
+ win_certificate_store:
+ path: C:\Temp\cert.pfx
+ state: present
+ key_exportable: no
+ # usually you don't set this here but it is for illustrative purposes
+ vars:
+ ansible_winrm_transport: credssp
+
+- name: Remove a certificate based on file thumbprint
+ win_certificate_store:
+ path: C:\Temp\cert.pem
+ state: absent
+
+- name: Remove a certificate based on thumbprint
+ win_certificate_store:
+ thumbprint: BD7AF104CF1872BDB518D95C9534EA941665FD27
+ state: absent
+
+- name: Remove certificate based on thumbprint in the CurrentUser/TrustedPublisher store
+ win_certificate_store:
+ thumbprint: BD7AF104CF1872BDB518D95C9534EA941665FD27
+ state: absent
+ store_location: CurrentUser
+ store_name: TrustedPublisher
+
+- name: Export certificate as der encoded file
+ win_certificate_store:
+ path: C:\Temp\cert.cer
+ state: exported
+ file_type: der
+
+- name: Export certificate and key as pfx encoded file
+ win_certificate_store:
+ path: C:\Temp\cert.pfx
+ state: exported
+ file_type: pkcs12
+ password: AnotherStrongPass!
+ become: yes
+ become_method: runas
+ become_user: SYSTEM
+
+- name: Import certificate to be used by IIS
+ win_certificate_store:
+ path: C:\Temp\cert.pfx
+ file_type: pkcs12
+ password: StrongPassword!
+ store_location: LocalMachine
+ key_storage: machine
+ state: present
+'''
+
+RETURN = r'''
+thumbprints:
+ description: A list of certificate thumbprints that were touched by the
+ module.
+ returned: success
+ type: list
+ sample: ["BC05633694E675449136679A658281F17A191087"]
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_command.ps1 b/test/support/windows-integration/plugins/modules/win_command.ps1
new file mode 100644
index 00000000..e2a30650
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_command.ps1
@@ -0,0 +1,78 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.CommandUtil
+#Requires -Module Ansible.ModuleUtils.FileUtil
+
+# TODO: add check mode support
+
+Set-StrictMode -Version 2
+$ErrorActionPreference = 'Stop'
+
+$params = Parse-Args $args -supports_check_mode $false
+
+$raw_command_line = Get-AnsibleParam -obj $params -name "_raw_params" -type "str" -failifempty $true
+$chdir = Get-AnsibleParam -obj $params -name "chdir" -type "path"
+$creates = Get-AnsibleParam -obj $params -name "creates" -type "path"
+$removes = Get-AnsibleParam -obj $params -name "removes" -type "path"
+$stdin = Get-AnsibleParam -obj $params -name "stdin" -type "str"
+$output_encoding_override = Get-AnsibleParam -obj $params -name "output_encoding_override" -type "str"
+
+$raw_command_line = $raw_command_line.Trim()
+
+$result = @{
+ changed = $true
+ cmd = $raw_command_line
+}
+
+if ($creates -and $(Test-AnsiblePath -Path $creates)) {
+ Exit-Json @{msg="skipped, since $creates exists";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0}
+}
+
+if ($removes -and -not $(Test-AnsiblePath -Path $removes)) {
+ Exit-Json @{msg="skipped, since $removes does not exist";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0}
+}
+
+$command_args = @{
+ command = $raw_command_line
+}
+if ($chdir) {
+ $command_args['working_directory'] = $chdir
+}
+if ($stdin) {
+ $command_args['stdin'] = $stdin
+}
+if ($output_encoding_override) {
+ $command_args['output_encoding_override'] = $output_encoding_override
+}
+
+$start_datetime = [DateTime]::UtcNow
+try {
+ $command_result = Run-Command @command_args
+} catch {
+ $result.changed = $false
+ try {
+ $result.rc = $_.Exception.NativeErrorCode
+ } catch {
+ $result.rc = 2
+ }
+ Fail-Json -obj $result -message $_.Exception.Message
+}
+
+$result.stdout = $command_result.stdout
+$result.stderr = $command_result.stderr
+$result.rc = $command_result.rc
+
+$end_datetime = [DateTime]::UtcNow
+$result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff")
+$result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff")
+$result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff")
+
+If ($result.rc -ne 0) {
+ Fail-Json -obj $result -message "non-zero return code"
+}
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_command.py b/test/support/windows-integration/plugins/modules/win_command.py
new file mode 100644
index 00000000..508419b2
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_command.py
@@ -0,0 +1,136 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2016, Ansible, inc
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_command
+short_description: Executes a command on a remote Windows node
+version_added: 2.2
+description:
+ - The C(win_command) module takes the command name followed by a list of space-delimited arguments.
+ - The given command will be executed on all selected nodes. It will not be
+ processed through the shell, so variables like C($env:HOME) and operations
+ like C("<"), C(">"), C("|"), and C(";") will not work (use the M(win_shell)
+ module if you need these features).
+ - For non-Windows targets, use the M(command) module instead.
+options:
+ free_form:
+ description:
+ - The C(win_command) module takes a free form command to run.
+ - There is no parameter actually named 'free form'. See the examples!
+ type: str
+ required: yes
+ creates:
+ description:
+ - A path or path filter pattern; when the referenced path exists on the target host, the task will be skipped.
+ type: path
+ removes:
+ description:
+ - A path or path filter pattern; when the referenced path B(does not) exist on the target host, the task will be skipped.
+ type: path
+ chdir:
+ description:
+ - Set the specified path as the current working directory before executing a command.
+ type: path
+ stdin:
+ description:
+ - Set the stdin of the command directly to the specified value.
+ type: str
+ version_added: '2.5'
+ output_encoding_override:
+ description:
+ - This option overrides the encoding of stdout/stderr output.
+ - You can use this option when you need to run a command which ignores the console's codepage.
+ - You should only need to use this option in very rare circumstances.
+ - This value can be any valid encoding C(Name) based on the output of C([System.Text.Encoding]::GetEncodings()).
+ See U(https://docs.microsoft.com/dotnet/api/system.text.encoding.getencodings).
+ type: str
+ version_added: '2.10'
+notes:
+ - If you want to run a command through a shell (say you are using C(<),
+ C(>), C(|), etc), you actually want the M(win_shell) module instead. The
+ C(win_command) module is much more secure as it's not affected by the user's
+ environment.
+ - C(creates), C(removes), and C(chdir) can be specified after the command. For instance, if you only want to run a command if a certain file does not
+ exist, use this.
+seealso:
+- module: command
+- module: psexec
+- module: raw
+- module: win_psexec
+- module: win_shell
+author:
+ - Matt Davis (@nitzmahone)
+'''
+
+EXAMPLES = r'''
+- name: Save the result of 'whoami' in 'whoami_out'
+ win_command: whoami
+ register: whoami_out
+
+- name: Run command that only runs if folder exists and runs from a specific folder
+ win_command: wbadmin -backupTarget:C:\backup\
+ args:
+ chdir: C:\somedir\
+ creates: C:\backup\
+
+- name: Run an executable and send data to the stdin for the executable
+ win_command: powershell.exe -
+ args:
+ stdin: Write-Host test
+'''
+
+RETURN = r'''
+msg:
+ description: changed
+ returned: always
+ type: bool
+ sample: true
+start:
+ description: The command execution start time
+ returned: always
+ type: str
+ sample: '2016-02-25 09:18:26.429568'
+end:
+ description: The command execution end time
+ returned: always
+ type: str
+ sample: '2016-02-25 09:18:26.755339'
+delta:
+ description: The command execution delta time
+ returned: always
+ type: str
+ sample: '0:00:00.325771'
+stdout:
+ description: The command standard output
+ returned: always
+ type: str
+ sample: 'Clustering node rabbit@slave1 with rabbit@master ...'
+stderr:
+ description: The command standard error
+ returned: always
+ type: str
+ sample: 'ls: cannot access foo: No such file or directory'
+cmd:
+ description: The command executed by the task
+ returned: always
+ type: str
+ sample: 'rabbitmqctl join_cluster rabbit@master'
+rc:
+ description: The command return code (0 means success)
+ returned: always
+ type: int
+ sample: 0
+stdout_lines:
+ description: The command standard output split in lines
+ returned: always
+ type: list
+ sample: [u'Clustering node rabbit@slave1 with rabbit@master ...']
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_copy.ps1 b/test/support/windows-integration/plugins/modules/win_copy.ps1
new file mode 100644
index 00000000..6a26ee72
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_copy.ps1
@@ -0,0 +1,403 @@
+#!powershell
+
+# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk>
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.Backup
+
+$ErrorActionPreference = 'Stop'
+
+$params = Parse-Args -arguments $args -supports_check_mode $true
+$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false
+$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false
+
+# there are 4 modes to win_copy which are driven by the action plugins:
+# explode: src is a zip file which needs to be extracted to dest, for use with multiple files
+# query: win_copy action plugin wants to get the state of remote files to check whether it needs to send them
+# remote: all copy action is happening remotely (remote_src=True)
+# single: a single file has been copied, also used with template
+$copy_mode = Get-AnsibleParam -obj $params -name "_copy_mode" -type "str" -default "single" -validateset "explode","query","remote","single"
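+
+# Illustrative parameter shapes per mode (assumptions based on the comments above;
+# the actual values are set by the win_copy action plugin):
+#   single:  _copy_mode=single  src=C:\local\tmp dest=C:\dest\file.txt _original_basename=file.txt
+#   explode: _copy_mode=explode src=C:\local\files.zip dest=C:\dest\
+#   remote:  _copy_mode=remote  src=C:\src\dir dest=C:\dest\
+#   query:   _copy_mode=query   dest=C:\dest\ files=[...] directories=[...]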
+
+# used in explode, remote and single mode
+$src = Get-AnsibleParam -obj $params -name "src" -type "path" -failifempty ($copy_mode -in @("explode","process","single"))
+$dest = Get-AnsibleParam -obj $params -name "dest" -type "path" -failifempty $true
+$backup = Get-AnsibleParam -obj $params -name "backup" -type "bool" -default $false
+
+# used in single mode
+$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str"
+
+# used in query and remote mode
+$force = Get-AnsibleParam -obj $params -name "force" -type "bool" -default $true
+
+# used in query mode, contains the local files/directories/symlinks that are to be copied
+$files = Get-AnsibleParam -obj $params -name "files" -type "list"
+$directories = Get-AnsibleParam -obj $params -name "directories" -type "list"
+
+$result = @{
+ changed = $false
+}
+
+if ($diff_mode) {
+ $result.diff = @{}
+}
+
+Function Copy-File($source, $dest) {
+ $diff = ""
+ $copy_file = $false
+ $source_checksum = $null
+ if ($force) {
+ $source_checksum = Get-FileChecksum -path $source
+ }
+
+ if (Test-Path -LiteralPath $dest -PathType Container) {
+ Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': dest is already a folder"
+ } elseif (Test-Path -LiteralPath $dest -PathType Leaf) {
+ if ($force) {
+ $target_checksum = Get-FileChecksum -path $dest
+ if ($source_checksum -ne $target_checksum) {
+ $copy_file = $true
+ }
+ }
+ } else {
+ $copy_file = $true
+ }
+
+ if ($copy_file) {
+ $file_dir = [System.IO.Path]::GetDirectoryName($dest)
+ # validate the parent dir is not a file and that it exists
+ if (Test-Path -LiteralPath $file_dir -PathType Leaf) {
+ Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder"
+ } elseif (-not (Test-Path -LiteralPath $file_dir)) {
+ # directory doesn't exist, need to create
+ New-Item -Path $file_dir -ItemType Directory -WhatIf:$check_mode | Out-Null
+ $diff += "+$file_dir\`n"
+ }
+
+ if ($backup) {
+ $result.backup_file = Backup-File -path $dest -WhatIf:$check_mode
+ }
+
+ if (Test-Path -LiteralPath $dest -PathType Leaf) {
+ Remove-Item -LiteralPath $dest -Force -Recurse -WhatIf:$check_mode | Out-Null
+ $diff += "-$dest`n"
+ }
+
+ if (-not $check_mode) {
+ # cannot run with -WhatIf:$check_mode: if the parent dir didn't exist
+ # and was created above, it would still not exist in check mode
+ Copy-Item -LiteralPath $source -Destination $dest -Force | Out-Null
+ }
+ $diff += "+$dest`n"
+
+ $result.changed = $true
+ }
+
+ # ugly but to save us from running the checksum twice, let's return it for
+ # the main code to add it to $result
+ return ,@{ diff = $diff; checksum = $source_checksum }
+}
+
+Function Copy-Folder($source, $dest) {
+ $diff = ""
+
+ if (-not (Test-Path -LiteralPath $dest -PathType Container)) {
+ $parent_dir = [System.IO.Path]::GetDirectoryName($dest)
+ if (Test-Path -LiteralPath $parent_dir -PathType Leaf) {
+ Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder"
+ }
+ if (Test-Path -LiteralPath $dest -PathType Leaf) {
+ Fail-Json -obj $result -message "cannot copy folder from '$source' to '$dest': dest is already a file"
+ }
+
+ New-Item -Path $dest -ItemType Container -WhatIf:$check_mode | Out-Null
+ $diff += "+$dest\`n"
+ $result.changed = $true
+ }
+
+ $child_items = Get-ChildItem -LiteralPath $source -Force
+ foreach ($child_item in $child_items) {
+ $dest_child_path = Join-Path -Path $dest -ChildPath $child_item.Name
+ if ($child_item.PSIsContainer) {
+ $diff += (Copy-Folder -source $child_item.Fullname -dest $dest_child_path)
+ } else {
+ $diff += (Copy-File -source $child_item.Fullname -dest $dest_child_path).diff
+ }
+ }
+
+ return $diff
+}
+
+Function Get-FileSize($path) {
+ $file = Get-Item -LiteralPath $path -Force
+ if ($file.PSIsContainer) {
+ $size = (Get-ChildItem -Literalpath $file.FullName -Recurse -Force | `
+ Where-Object { $_.PSObject.Properties.Name -contains 'Length' } | `
+ Measure-Object -Property Length -Sum).Sum
+ if ($null -eq $size) {
+ $size = 0
+ }
+ } else {
+ $size = $file.Length
+ }
+
+ $size
+}
+
+Function Extract-Zip($src, $dest) {
+ $archive = [System.IO.Compression.ZipFile]::Open($src, [System.IO.Compression.ZipArchiveMode]::Read, [System.Text.Encoding]::UTF8)
+ foreach ($entry in $archive.Entries) {
+ $archive_name = $entry.FullName
+
+ # entry names are base64-encoded; a valid base64 string's length is a
+ # multiple of 4, so length % 4 == 1 means a directory separator (/ or \)
+ # was appended and needs to be stripped before decoding
+ $padding_length = $archive_name.Length % 4
+ if ($padding_length -eq 0) {
+ $is_dir = $false
+ $base64_name = $archive_name
+ } elseif ($padding_length -eq 1) {
+ $is_dir = $true
+ if ($archive_name.EndsWith("/") -or $archive_name.EndsWith("`\")) {
+ $base64_name = $archive_name.Substring(0, $archive_name.Length - 1)
+ } else {
+ throw "invalid base64 archive name '$archive_name'"
+ }
+ } else {
+ throw "invalid base64 length '$archive_name'"
+ }
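+ # e.g. "Zm9vLnR4dA==" (base64 of "foo.txt", length 12, 12 % 4 -eq 0) is a file entry,
+ # while "YmFy/" ("bar" encoded plus a trailing separator, length 5, 5 % 4 -eq 1) is a directory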
+
+ # to handle unicode character, win_copy action plugin has encoded the filename
+ $decoded_archive_name = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($base64_name))
+ # re-add the / to the entry full name if it was a directory
+ if ($is_dir) {
+ $decoded_archive_name = "$decoded_archive_name/"
+ }
+ $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_name)
+ $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path)
+
+ if (-not (Test-Path -LiteralPath $entry_dir)) {
+ New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null
+ }
+
+ if ($is_dir -eq $false) {
+ if (-not $check_mode) {
+ [System.IO.Compression.ZipFileExtensions]::ExtractToFile($entry, $entry_target_path, $true)
+ }
+ }
+ }
+ $archive.Dispose() # release the handle of the zip file
+}
+
+Function Extract-ZipLegacy($src, $dest) {
+ if (-not (Test-Path -LiteralPath $dest)) {
+ New-Item -Path $dest -ItemType Directory -WhatIf:$check_mode | Out-Null
+ }
+ $shell = New-Object -ComObject Shell.Application
+ $zip = $shell.NameSpace($src)
+ $dest_path = $shell.NameSpace($dest)
+
+ foreach ($entry in $zip.Items()) {
+ $is_dir = $entry.IsFolder
+ $encoded_archive_entry = $entry.Name
+ # to handle unicode character, win_copy action plugin has encoded the filename
+ $decoded_archive_entry = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($encoded_archive_entry))
+ if ($is_dir) {
+ $decoded_archive_entry = "$decoded_archive_entry/"
+ }
+
+ $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_entry)
+ $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path)
+
+ if (-not (Test-Path -LiteralPath $entry_dir)) {
+ New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null
+ }
+
+ if ($is_dir -eq $false -and (-not $check_mode)) {
+ # https://msdn.microsoft.com/en-us/library/windows/desktop/bb787866.aspx
+ # From Folder.CopyHere documentation, 1044 means:
+ # - 1024: do not display a user interface if an error occurs
+ # - 16: respond with "yes to all" for any dialog box that is displayed
+ # - 4: do not display a progress dialog box
+ $dest_path.CopyHere($entry, 1044)
+
+ # once the file is extracted, we need to rename it to its non-base64 name
+ $combined_encoded_path = [System.IO.Path]::Combine($dest, $encoded_archive_entry)
+ Move-Item -LiteralPath $combined_encoded_path -Destination $entry_target_path -Force | Out-Null
+ }
+ }
+}
+
+if ($copy_mode -eq "query") {
+ # we only return a list of files/directories that need to be copied over
+ # the source of the local file will be the key used
+ $changed_files = @()
+ $changed_directories = @()
+ $changed_symlinks = @()
+
+ foreach ($file in $files) {
+ $filename = $file.dest
+ $local_checksum = $file.checksum
+
+ $filepath = Join-Path -Path $dest -ChildPath $filename
+ if (Test-Path -LiteralPath $filepath -PathType Leaf) {
+ if ($force) {
+ $checksum = Get-FileChecksum -path $filepath
+ if ($checksum -ne $local_checksum) {
+ $changed_files += $file
+ }
+ }
+ } elseif (Test-Path -LiteralPath $filepath -PathType Container) {
+ Fail-Json -obj $result -message "cannot copy file to dest '$filepath': object at path is already a directory"
+ } else {
+ $changed_files += $file
+ }
+ }
+
+ foreach ($directory in $directories) {
+ $dirname = $directory.dest
+
+ $dirpath = Join-Path -Path $dest -ChildPath $dirname
+ $parent_dir = [System.IO.Path]::GetDirectoryName($dirpath)
+ if (Test-Path -LiteralPath $parent_dir -PathType Leaf) {
+ Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at parent directory path is already a file"
+ }
+ if (Test-Path -LiteralPath $dirpath -PathType Leaf) {
+ Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at path is already a file"
+ } elseif (-not (Test-Path -LiteralPath $dirpath -PathType Container)) {
+ $changed_directories += $directory
+ }
+ }
+
+ # TODO: Handle symlinks
+
+ $result.files = $changed_files
+ $result.directories = $changed_directories
+ $result.symlinks = $changed_symlinks
+} elseif ($copy_mode -eq "explode") {
+ # a single zip file containing the files and directories needs to be
+ # expanded. This always results in a change, as the calculation is done
+ # by the win_copy action plugin and this mode only runs when a change is needed
+ if (-not (Test-Path -LiteralPath $src -PathType Leaf)) {
+ Fail-Json -obj $result -message "Cannot expand src zip file: '$src' as it does not exist"
+ }
+
+ # Detect if the PS zip assemblies are available or whether to use Shell
+ $use_legacy = $false
+ try {
+ Add-Type -AssemblyName System.IO.Compression.FileSystem | Out-Null
+ Add-Type -AssemblyName System.IO.Compression | Out-Null
+ } catch {
+ $use_legacy = $true
+ }
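+    # Note: System.IO.Compression.FileSystem is only available on .NET 4.5+, so on
+    # older hosts the Add-Type calls above throw and we fall back to the COM API.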
+ if ($use_legacy) {
+ Extract-ZipLegacy -src $src -dest $dest
+ } else {
+ Extract-Zip -src $src -dest $dest
+ }
+
+ $result.changed = $true
+} elseif ($copy_mode -eq "remote") {
+    # all copy actions are happening on the remote side (Windows host), so we
+    # need to copy from source to dest using PS code
+ $result.src = $src
+ $result.dest = $dest
+
+ if (-not (Test-Path -LiteralPath $src)) {
+ Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist"
+ }
+
+ if (Test-Path -LiteralPath $src -PathType Container) {
+ # we are copying a directory or the contents of a directory
+ $result.operation = 'folder_copy'
+ if ($src.EndsWith("/") -or $src.EndsWith("`\")) {
+ # copying the folder's contents to dest
+ $diff = ""
+ $child_files = Get-ChildItem -LiteralPath $src -Force
+ foreach ($child_file in $child_files) {
+ $dest_child_path = Join-Path -Path $dest -ChildPath $child_file.Name
+ if ($child_file.PSIsContainer) {
+ $diff += Copy-Folder -source $child_file.FullName -dest $dest_child_path
+ } else {
+ $diff += (Copy-File -source $child_file.FullName -dest $dest_child_path).diff
+ }
+ }
+ } else {
+        # copying the folder and its contents to dest
+ $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name
+ $result.dest = $dest
+ $diff = Copy-Folder -source $src -dest $dest
+ }
+ } else {
+ # we are just copying a single file to dest
+ $result.operation = 'file_copy'
+
+ $source_basename = (Get-Item -LiteralPath $src -Force).Name
+ $result.original_basename = $source_basename
+
+ if ($dest.EndsWith("/") -or $dest.EndsWith("`\")) {
+ $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name
+ $result.dest = $dest
+ } else {
+            # check if the parent dir exists, this is only done if src is a
+            # file and dest is the path to a file (doesn't end with \ or /)
+ $parent_dir = Split-Path -LiteralPath $dest
+ if (Test-Path -LiteralPath $parent_dir -PathType Leaf) {
+ Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file"
+ } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) {
+ Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist"
+ }
+ }
+ $copy_result = Copy-File -source $src -dest $dest
+ $diff = $copy_result.diff
+ $result.checksum = $copy_result.checksum
+ }
+
+ # the file might not exist if running in check mode
+ if (-not $check_mode -or (Test-Path -LiteralPath $dest -PathType Leaf)) {
+ $result.size = Get-FileSize -path $dest
+ } else {
+ $result.size = $null
+ }
+ if ($diff_mode) {
+ $result.diff.prepared = $diff
+ }
+} elseif ($copy_mode -eq "single") {
+    # a single file is located at src and needs to be copied to dest; this will
+    # always result in a change as the calculation is done on the Ansible side
+    # before this is run. This should also never run in check mode
+ if (-not (Test-Path -LiteralPath $src -PathType Leaf)) {
+ Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist"
+ }
+
+    # the dest parameter is a directory, so we need to append original_basename
+ if ($dest.EndsWith("/") -or $dest.EndsWith("`\") -or (Test-Path -LiteralPath $dest -PathType Container)) {
+ $remote_dest = Join-Path -Path $dest -ChildPath $original_basename
+ $parent_dir = Split-Path -LiteralPath $remote_dest
+
+ # when dest ends with /, we need to create the destination directories
+ if (Test-Path -LiteralPath $parent_dir -PathType Leaf) {
+ Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file"
+ } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) {
+ New-Item -Path $parent_dir -ItemType Directory | Out-Null
+ }
+ } else {
+ $remote_dest = $dest
+ $parent_dir = Split-Path -LiteralPath $remote_dest
+
+ # check if the dest parent dirs exist, need to fail if they don't
+ if (Test-Path -LiteralPath $parent_dir -PathType Leaf) {
+ Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file"
+ } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) {
+ Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist"
+ }
+ }
+
+ if ($backup) {
+ $result.backup_file = Backup-File -path $remote_dest -WhatIf:$check_mode
+ }
+
+ Copy-Item -LiteralPath $src -Destination $remote_dest -Force | Out-Null
+ $result.changed = $true
+}
+
+Exit-Json -obj $result
diff --git a/test/support/windows-integration/plugins/modules/win_copy.py b/test/support/windows-integration/plugins/modules/win_copy.py
new file mode 100644
index 00000000..a55f4c65
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_copy.py
@@ -0,0 +1,207 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk>
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_copy
+version_added: '1.9.2'
+short_description: Copies files to remote locations on Windows hosts
+description:
+- The C(win_copy) module copies a file on the local box to remote windows locations.
+- For non-Windows targets, use the M(copy) module instead.
+options:
+ content:
+ description:
+ - When used instead of C(src), sets the contents of a file directly to the
+ specified value.
+ - This is for simple values, for anything complex or with formatting please
+ switch to the M(template) module.
+ type: str
+ version_added: '2.3'
+ decrypt:
+ description:
+ - This option controls the autodecryption of source files using vault.
+ type: bool
+ default: yes
+ version_added: '2.5'
+ dest:
+ description:
+ - Remote absolute path where the file should be copied to.
+ - If C(src) is a directory, this must be a directory too.
+ - Use \ for path separators or \\ when in "double quotes".
+ - If C(dest) ends with \ then source or the contents of source will be
+ copied to the directory without renaming.
+ - If C(dest) is a nonexistent path, it will only be created if C(dest) ends
+ with "/" or "\", or C(src) is a directory.
+ - If C(src) and C(dest) are files and if the parent directory of C(dest)
+ doesn't exist, then the task will fail.
+ type: path
+ required: yes
+ backup:
+ description:
+ - Determine whether a backup should be created.
+ - When set to C(yes), create a backup file including the timestamp information
+ so you can get the original file back if you somehow clobbered it incorrectly.
+ - No backup is taken when C(remote_src=False) and multiple files are being
+ copied.
+ type: bool
+ default: no
+ version_added: '2.8'
+ force:
+ description:
+    - If set to C(yes), the file will only be transferred if the content
+      is different than the destination.
+    - If set to C(no), the file will only be transferred if the
+      destination does not exist.
+    - If set to C(no), no checksumming of the content is performed, which can
+      help improve performance on larger files.
+ type: bool
+ default: yes
+ version_added: '2.3'
+ local_follow:
+ description:
+ - This flag indicates that filesystem links in the source tree, if they
+ exist, should be followed.
+ type: bool
+ default: yes
+ version_added: '2.4'
+ remote_src:
+ description:
+    - If C(no), it will search for src on the originating/master machine.
+ - If C(yes), it will go to the remote/target machine for the src.
+ type: bool
+ default: no
+ version_added: '2.3'
+ src:
+ description:
+ - Local path to a file to copy to the remote server; can be absolute or
+ relative.
+ - If path is a directory, it is copied (including the source folder name)
+ recursively to C(dest).
+ - If path is a directory and ends with "/", only the inside contents of
+ that directory are copied to the destination. Otherwise, if it does not
+ end with "/", the directory itself with all contents is copied.
+ - If path is a file and dest ends with "\", the file is copied to the
+ folder with the same filename.
+ - Required unless using C(content).
+ type: path
+notes:
+- Currently win_copy does not support copying symbolic links, either from local
+  to remote or from remote to remote.
+- It is recommended that backslashes C(\) are used instead of C(/) when dealing
+ with remote paths.
+- Because win_copy runs over WinRM, it is not a very efficient transfer
+ mechanism. If sending large files consider hosting them on a web service and
+ using M(win_get_url) instead.
+seealso:
+- module: assemble
+- module: copy
+- module: win_get_url
+- module: win_robocopy
+author:
+- Jon Hawkesworth (@jhawkesworth)
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+- name: Copy a single file
+ win_copy:
+ src: /srv/myfiles/foo.conf
+ dest: C:\Temp\renamed-foo.conf
+
+- name: Copy a single file, but keep a backup
+ win_copy:
+ src: /srv/myfiles/foo.conf
+ dest: C:\Temp\renamed-foo.conf
+ backup: yes
+
+- name: Copy a single file keeping the filename
+ win_copy:
+ src: /src/myfiles/foo.conf
+ dest: C:\Temp\
+
+- name: Copy folder to C:\Temp (results in C:\Temp\temp_files)
+ win_copy:
+ src: files/temp_files
+ dest: C:\Temp
+
+- name: Copy folder contents recursively
+ win_copy:
+ src: files/temp_files/
+ dest: C:\Temp
+
+- name: Copy a single file where the source is on the remote host
+ win_copy:
+ src: C:\Temp\foo.txt
+ dest: C:\ansible\foo.txt
+ remote_src: yes
+
+- name: Copy a folder recursively where the source is on the remote host
+ win_copy:
+ src: C:\Temp
+ dest: C:\ansible
+ remote_src: yes
+
+- name: Set the contents of a file
+ win_copy:
+ content: abc123
+ dest: C:\Temp\foo.txt
+
+- name: Copy a single file as another user
+ win_copy:
+ src: NuGet.config
+ dest: '%AppData%\NuGet\NuGet.config'
+ vars:
+ ansible_become_user: user
+ ansible_become_password: pass
+ # The tmp dir must be set when using win_copy as another user
+ # This ensures the become user will have permissions for the operation
+ # Make sure to specify a folder both the ansible_user and the become_user have access to (i.e not %TEMP% which is user specific and requires Admin)
+ ansible_remote_tmp: 'c:\tmp'
+'''
+
+RETURN = r'''
+backup_file:
+ description: Name of the backup file that was created.
+ returned: if backup=yes
+ type: str
+ sample: C:\Path\To\File.txt.11540.20150212-220915.bak
+dest:
+ description: Destination file/path.
+ returned: changed
+ type: str
+ sample: C:\Temp\
+src:
+ description: Source file used for the copy on the target machine.
+ returned: changed
+ type: str
+ sample: /home/httpd/.ansible/tmp/ansible-tmp-1423796390.97-147729857856000/source
+checksum:
+ description: SHA1 checksum of the file after running copy.
+ returned: success, src is a file
+ type: str
+ sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827
+size:
+ description: Size of the target, after execution.
+ returned: changed, src is a file
+ type: int
+ sample: 1220
+operation:
+ description: Whether a single file copy took place or a folder copy.
+ returned: success
+ type: str
+ sample: file_copy
+original_basename:
+ description: Basename of the copied file.
+ returned: changed, src is a file
+ type: str
+ sample: foo.txt
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_data_deduplication.ps1 b/test/support/windows-integration/plugins/modules/win_data_deduplication.ps1
new file mode 100644
index 00000000..593ee763
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_data_deduplication.ps1
@@ -0,0 +1,129 @@
+#!powershell
+
+# Copyright: 2019, rnsc(@rnsc) <github@rnsc.be>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+#AnsibleRequires -OSVersion 6.3
+
+$spec = @{
+ options = @{
+ drive_letter = @{ type = "str"; required = $true }
+ state = @{ type = "str"; choices = "absent", "present"; default = "present"; }
+ settings = @{
+ type = "dict"
+ required = $false
+ options = @{
+ minimum_file_size = @{ type = "int"; default = 32768 }
+ minimum_file_age_days = @{ type = "int"; default = 2 }
+ no_compress = @{ type = "bool"; required = $false; default = $false }
+ optimize_in_use_files = @{ type = "bool"; required = $false; default = $false }
+ verify = @{ type = "bool"; required = $false; default = $false }
+ }
+ }
+ }
+ supports_check_mode = $true
+}
+
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+
+$drive_letter = $module.Params.drive_letter
+$state = $module.Params.state
+$settings = $module.Params.settings
+
+$module.Result.changed = $false
+$module.Result.reboot_required = $false
+$module.Result.msg = ""
+
+function Set-DataDeduplication($volume, $state, $settings) {
+
+ $current_state = 'absent'
+
+ try {
+ $dedup_info = Get-DedupVolume -Volume "$($volume.DriveLetter):"
+ } catch {
+ $dedup_info = $null
+ }
+
+ if ($dedup_info.Enabled) {
+ $current_state = 'present'
+ }
+
+    if ($state -ne $current_state) {
+        if (-not $module.CheckMode) {
+            if ($state -eq 'present') {
+ # Enable-DedupVolume -Volume <String>
+ Enable-DedupVolume -Volume "$($volume.DriveLetter):"
+ } elseif ($state -eq 'absent') {
+ Disable-DedupVolume -Volume "$($volume.DriveLetter):"
+ }
+ }
+ $module.Result.changed = $true
+ }
+
+ if ($state -eq 'present') {
+ if ($null -ne $settings) {
+ Set-DataDedupJobSettings -volume $volume -settings $settings
+ }
+ }
+}
+
+function Set-DataDedupJobSettings ($volume, $settings) {
+
+ try {
+ $dedup_info = Get-DedupVolume -Volume "$($volume.DriveLetter):"
+ } catch {
+ $dedup_info = $null
+ }
+
+ ForEach ($key in $settings.keys) {
+
+ # See Microsoft documentation:
+ # https://docs.microsoft.com/en-us/powershell/module/deduplication/set-dedupvolume?view=win10-ps
+
+ $update_key = $key
+ $update_value = $settings.$($key)
+        # Transform Ansible-style option names to PowerShell parameter names
+        $update_key = $update_key -replace '_', ''
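+        # e.g. 'minimum_file_size' becomes 'minimumfilesize', which still binds to
+        # -MinimumFileSize because PowerShell parameter binding is case-insensitive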
+
+ if ($update_key -eq "MinimumFileSize" -and $update_value -lt 32768) {
+ $update_value = 32768
+ }
+
+ $current_value = ($dedup_info | Select-Object -ExpandProperty $update_key)
+
+ if ($update_value -ne $current_value) {
+ $command_param = @{
+ $($update_key) = $update_value
+ }
+
+ # Set-DedupVolume -Volume <String>`
+ # -NoCompress <bool> `
+ # -MinimumFileAgeDays <UInt32> `
+ # -MinimumFileSize <UInt32> (minimum 32768)
+            if (-not $module.CheckMode) {
+ Set-DedupVolume -Volume "$($volume.DriveLetter):" @command_param
+ }
+
+ $module.Result.changed = $true
+ }
+ }
+
+}
+
+# Install required feature
+$feature_name = "FS-Data-Deduplication"
+if (-not $module.CheckMode) {
+ $feature = Install-WindowsFeature -Name $feature_name
+
+ if ($feature.RestartNeeded -eq 'Yes') {
+ $module.Result.reboot_required = $true
+ $module.FailJson("$feature_name was installed but requires Windows to be rebooted to work.")
+ }
+}
+
+$volume = Get-Volume -DriveLetter $drive_letter
+
+Set-DataDeduplication -volume $volume -state $state -settings $settings
+
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_data_deduplication.py b/test/support/windows-integration/plugins/modules/win_data_deduplication.py
new file mode 100644
index 00000000..d320b9f7
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_data_deduplication.py
@@ -0,0 +1,87 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: 2019, rnsc(@rnsc) <github@rnsc.be>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_data_deduplication
+version_added: "2.10"
+short_description: Module to enable Data Deduplication on a volume.
+description:
+- This module can be used to enable Data Deduplication on a Windows volume.
+- The module will install the FS-Data-Deduplication feature (a reboot will be necessary).
+options:
+ drive_letter:
+ description:
+ - Windows drive letter on which to enable data deduplication.
+ required: yes
+ type: str
+ state:
+ description:
+    - Whether to enable or disable data deduplication on the selected volume.
+ default: present
+ type: str
+ choices: [ present, absent ]
+ settings:
+ description:
+ - Dictionary of settings to pass to the Set-DedupVolume powershell command.
+ type: dict
+ suboptions:
+ minimum_file_size:
+ description:
+ - Minimum file size you want to target for deduplication.
+ - It will default to 32768 if not defined or if the value is less than 32768.
+ type: int
+ default: 32768
+ minimum_file_age_days:
+ description:
+ - Minimum file age you want to target for deduplication.
+ type: int
+ default: 2
+ no_compress:
+ description:
+        - Whether to disable the compression of deduplicated data or not.
+ type: bool
+ default: no
+ optimize_in_use_files:
+ description:
+ - Indicates that the server attempts to optimize currently open files.
+ type: bool
+ default: no
+ verify:
+ description:
+ - Indicates whether the deduplication engine performs a byte-for-byte verification for each duplicate chunk
+ that optimization creates, rather than relying on a cryptographically strong hash.
+        - This option is not recommended.
+ - Setting this parameter to True can degrade optimization performance.
+ type: bool
+ default: no
+author:
+- rnsc (@rnsc)
+'''
+
+EXAMPLES = r'''
+- name: Enable Data Deduplication on D
+ win_data_deduplication:
+ drive_letter: 'D'
+ state: present
+
+- name: Enable Data Deduplication on D
+ win_data_deduplication:
+ drive_letter: 'D'
+ state: present
+ settings:
+ no_compress: true
+ minimum_file_age_days: 1
+ minimum_file_size: 0
+'''
+
+RETURN = r'''
+#
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_dsc.ps1 b/test/support/windows-integration/plugins/modules/win_dsc.ps1
new file mode 100644
index 00000000..690f391a
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_dsc.ps1
@@ -0,0 +1,398 @@
+#!powershell
+
+# Copyright: (c) 2015, Trond Hindenes <trond@hindenes.com>, and others
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+#Requires -Version 5
+
+Function ConvertTo-ArgSpecType {
+ <#
+ .SYNOPSIS
+ Converts the DSC parameter type to the arg spec type required for Ansible.
+ #>
+ param(
+ [Parameter(Mandatory=$true)][String]$CimType
+ )
+
+ $arg_type = switch($CimType) {
+ Boolean { "bool" }
+ Char16 { [Func[[Object], [Char]]]{ [System.Char]::Parse($args[0].ToString()) } }
+ DateTime { [Func[[Object], [DateTime]]]{ [System.DateTime]($args[0].ToString()) } }
+ Instance { "dict" }
+ Real32 { "float" }
+ Real64 { [Func[[Object], [Double]]]{ [System.Double]::Parse($args[0].ToString()) } }
+ Reference { "dict" }
+ SInt16 { [Func[[Object], [Int16]]]{ [System.Int16]::Parse($args[0].ToString()) } }
+ SInt32 { "int" }
+ SInt64 { [Func[[Object], [Int64]]]{ [System.Int64]::Parse($args[0].ToString()) } }
+ SInt8 { [Func[[Object], [SByte]]]{ [System.SByte]::Parse($args[0].ToString()) } }
+ String { "str" }
+ UInt16 { [Func[[Object], [UInt16]]]{ [System.UInt16]::Parse($args[0].ToString()) } }
+ UInt32 { [Func[[Object], [UInt32]]]{ [System.UInt32]::Parse($args[0].ToString()) } }
+ UInt64 { [Func[[Object], [UInt64]]]{ [System.UInt64]::Parse($args[0].ToString()) } }
+ UInt8 { [Func[[Object], [Byte]]]{ [System.Byte]::Parse($args[0].ToString()) } }
+ Unknown { "raw" }
+ default { "raw" }
+ }
+ return $arg_type
+}
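+# Note: Ansible.Basic accepts a [Func[Object, T]] delegate as an arg spec 'type' and
+# invokes it to cast the raw value, e.g. the 'SInt64' case above parses to an Int64,
+# while the simple cases map to built-in spec type names like "str" or "int".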
+
+Function Get-DscCimClassProperties {
+ <#
+ .SYNOPSIS
+    Gets a list of CimProperties of a CIM class. It filters out any magic or
+    read-only properties that we don't need to know about.
+ #>
+ param([Parameter(Mandatory=$true)][String]$ClassName)
+
+ $resource = Get-CimClass -ClassName $ClassName -Namespace root\Microsoft\Windows\DesiredStateConfiguration
+
+ # Filter out any magic properties that are used internally on an OMI_BaseResource
+ # https://github.com/PowerShell/PowerShell/blob/master/src/System.Management.Automation/DscSupport/CimDSCParser.cs#L1203
+ $magic_properties = @("ResourceId", "SourceInfo", "ModuleName", "ModuleVersion", "ConfigurationName")
+ $properties = $resource.CimClassProperties | Where-Object {
+
+ ($resource.CimSuperClassName -ne "OMI_BaseResource" -or $_.Name -notin $magic_properties) -and
+ -not $_.Flags.HasFlag([Microsoft.Management.Infrastructure.CimFlags]::ReadOnly)
+ }
+
+ return ,$properties
+}
+
+Function Add-PropertyOption {
+ <#
+ .SYNOPSIS
+ Adds the spec for the property type to the existing module specification.
+ #>
+ param(
+ [Parameter(Mandatory=$true)][Hashtable]$Spec,
+ [Parameter(Mandatory=$true)]
+ [Microsoft.Management.Infrastructure.CimPropertyDeclaration]$Property
+ )
+
+ $option = @{
+ required = $false
+ }
+ $property_name = $Property.Name
+ $property_type = $Property.CimType.ToString()
+
+ if ($Property.Flags.HasFlag([Microsoft.Management.Infrastructure.CimFlags]::Key) -or
+ $Property.Flags.HasFlag([Microsoft.Management.Infrastructure.CimFlags]::Required)) {
+ $option.required = $true
+ }
+
+ if ($null -ne $Property.Qualifiers['Values']) {
+ $option.choices = [System.Collections.Generic.List`1[Object]]$Property.Qualifiers['Values'].Value
+ }
+
+ if ($property_name -eq "Name") {
+ # For backwards compatibility we support specifying the Name DSC property as item_name
+ $option.aliases = @("item_name")
+ } elseif ($property_name -ceq "key") {
+ # There seems to be a bug in the CIM property parsing when the property name is 'Key'. The CIM instance will
+ # think the name is 'key' when the MOF actually defines it as 'Key'. We set the proper casing so the module arg
+        # validator won't fire a case-sensitivity warning
+ $property_name = "Key"
+ }
+
+ if ($Property.ReferenceClassName -eq "MSFT_Credential") {
+        # Special handling for the MSFT_Credential type (PSCredential); we handle this by having 2 options
+        # suffixed with _username and _password.
+ $option_spec_pass = @{
+ type = "str"
+ required = $option.required
+ no_log = $true
+ }
+ $Spec.options."$($property_name)_password" = $option_spec_pass
+ $Spec.required_together.Add(@("$($property_name)_username", "$($property_name)_password")) > $null
+
+ $property_name = "$($property_name)_username"
+ $option.type = "str"
+ } elseif ($Property.ReferenceClassName -eq "MSFT_KeyValuePair") {
+ $option.type = "dict"
+ } elseif ($property_type.EndsWith("Array")) {
+ $option.type = "list"
+ $option.elements = ConvertTo-ArgSpecType -CimType $property_type.Substring(0, $property_type.Length - 5)
+ } else {
+ $option.type = ConvertTo-ArgSpecType -CimType $property_type
+ }
+
+ if (($option.type -eq "dict" -or ($option.type -eq "list" -and $option.elements -eq "dict")) -and
+ $Property.ReferenceClassName -ne "MSFT_KeyValuePair") {
+        # Get the sub spec if the type is an Instance (CimInstance/dict)
+ $sub_option_spec = Get-OptionSpec -ClassName $Property.ReferenceClassName
+ $option += $sub_option_spec
+ }
+
+ $Spec.options.$property_name = $option
+}
+
+Function Get-OptionSpec {
+ <#
+ .SYNOPSIS
+    Generates the spec used by AnsibleModule for a CIM MOF resource name.
+
+ .NOTES
+    This won't be able to retrieve the default values for an option as that is not defined in the MOF for a resource.
+    Default values are still preserved in the DSC engine if we don't pass in the property at all; we just can't report
+    on what they are automatically.
+ #>
+ param(
+ [Parameter(Mandatory=$true)][String]$ClassName
+ )
+
+ $spec = @{
+ options = @{}
+ required_together = [System.Collections.ArrayList]@()
+ }
+ $properties = Get-DscCimClassProperties -ClassName $ClassName
+ foreach ($property in $properties) {
+ Add-PropertyOption -Spec $spec -Property $property
+ }
+
+ return $spec
+}
+
+Function ConvertTo-CimInstance {
+ <#
+ .SYNOPSIS
+    Converts a dict to a CimInstance of the specified class. Also provides a
+    better error message, containing the failing option name, if the conversion fails.
+ #>
+ param(
+ [Parameter(Mandatory=$true)][String]$Name,
+ [Parameter(Mandatory=$true)][String]$ClassName,
+ [Parameter(Mandatory=$true)][System.Collections.IDictionary]$Value,
+ [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module,
+ [Switch]$Recurse
+ )
+
+ $properties = @{}
+ foreach ($value_info in $Value.GetEnumerator()) {
+ # Need to remove all null values from existing dict so the conversion works
+ if ($null -eq $value_info.Value) {
+ continue
+ }
+ $properties.($value_info.Key) = $value_info.Value
+ }
+
+ if ($Recurse) {
+        # We want to validate and convert any values to what's required by DSC
+ $properties = ConvertTo-DscProperty -ClassName $ClassName -Params $properties -Module $Module
+ }
+
+ try {
+ return (New-CimInstance -ClassName $ClassName -Property $properties -ClientOnly)
+ } catch {
+ # New-CimInstance raises a poor error message, make sure we mention what option it is for
+ $Module.FailJson("Failed to cast dict value for option '$Name' to a CimInstance: $($_.Exception.Message)", $_)
+ }
+}
+
+Function ConvertTo-DscProperty {
+ <#
+ .SYNOPSIS
+    Converts the input module parameters that have been validated and cast
+    into the types expected by the DSC engine. This is mostly done to deal with
+    types like PSCredential and Dictionaries.
+ #>
+ param(
+ [Parameter(Mandatory=$true)][String]$ClassName,
+ [Parameter(Mandatory=$true)][System.Collections.IDictionary]$Params,
+ [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module
+ )
+ $properties = Get-DscCimClassProperties -ClassName $ClassName
+
+ $dsc_properties = @{}
+ foreach ($property in $properties) {
+ $property_name = $property.Name
+ $property_type = $property.CimType.ToString()
+
+ if ($property.ReferenceClassName -eq "MSFT_Credential") {
+ $username = $Params."$($property_name)_username"
+ $password = $Params."$($property_name)_password"
+
+ # No user set == No option set in playbook, skip this property
+ if ($null -eq $username) {
+ continue
+ }
+ $sec_password = ConvertTo-SecureString -String $password -AsPlainText -Force
+ $value = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList $username, $sec_password
+ } else {
+ $value = $Params.$property_name
+
+ # The actual value wasn't set, skip adding this property
+ if ($null -eq $value) {
+ continue
+ }
+
+ if ($property.ReferenceClassName -eq "MSFT_KeyValuePair") {
+ $key_value_pairs = [System.Collections.Generic.List`1[CimInstance]]@()
+ foreach ($value_info in $value.GetEnumerator()) {
+ $kvp = @{Key = $value_info.Key; Value = $value_info.Value.ToString()}
+ $cim_instance = ConvertTo-CimInstance -Name $property_name -ClassName MSFT_KeyValuePair `
+ -Value $kvp -Module $Module
+ $key_value_pairs.Add($cim_instance) > $null
+ }
+ $value = $key_value_pairs.ToArray()
+ } elseif ($null -ne $property.ReferenceClassName) {
+ # Convert the dict to a CimInstance (or list of CimInstances)
+ $convert_args = @{
+ ClassName = $property.ReferenceClassName
+ Module = $Module
+ Name = $property_name
+ Recurse = $true
+ }
+ if ($property_type.EndsWith("Array")) {
+ $value = [System.Collections.Generic.List`1[CimInstance]]@()
+ foreach ($raw in $Params.$property_name.GetEnumerator()) {
+ $cim_instance = ConvertTo-CimInstance -Value $raw @convert_args
+ $value.Add($cim_instance) > $null
+ }
+ $value = $value.ToArray() # Need to make sure we are dealing with an Array not a List
+ } else {
+ $value = ConvertTo-CimInstance -Value $value @convert_args
+ }
+ }
+ }
+ $dsc_properties.$property_name = $value
+ }
+
+ return $dsc_properties
+}
+
+Function Invoke-DscMethod {
+ <#
+ .SYNOPSIS
+    Invokes the specified DSC resource method in a separate PS pipeline. This is
+    done so we can retrieve the Verbose stream and return it back to the user
+    for further debugging.
+ #>
+ param(
+ [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module,
+ [Parameter(Mandatory=$true)][String]$Method,
+ [Parameter(Mandatory=$true)][Hashtable]$Arguments
+ )
+
+ # Invoke the DSC resource in a separate runspace so we can capture the Verbose output
+ $ps = [PowerShell]::Create()
+ $ps.AddCommand("Invoke-DscResource").AddParameter("Method", $Method) > $null
+ $ps.AddParameters($Arguments) > $null
+
+ $result = $ps.Invoke()
+
+ # Pass the warnings through to the AnsibleModule return result
+ foreach ($warning in $ps.Streams.Warning) {
+ $Module.Warn($warning.Message)
+ }
+
+ # If running at a high enough verbosity, add the verbose output to the AnsibleModule return result
+ if ($Module.Verbosity -ge 3) {
+ $verbose_logs = [System.Collections.Generic.List`1[String]]@()
+ foreach ($verbosity in $ps.Streams.Verbose) {
+ $verbose_logs.Add($verbosity.Message) > $null
+ }
+ $Module.Result."verbose_$($Method.ToLower())" = $verbose_logs
+ }
+
+ if ($ps.HadErrors) {
+ # Cannot pass in the ErrorRecord as it's a RemotingErrorRecord and doesn't contain the ScriptStackTrace
+ # or other info that would be useful
+ $Module.FailJson("Failed to invoke DSC $Method method: $($ps.Streams.Error[0].Exception.Message)")
+ }
+
+ return $result
+}
+
+# win_dsc is unique in that it builds the arg spec based on DSC Resource input. To get this info
+# we need to read the resource_name and module_version values, which is done outside of Ansible.Basic
+if ($args.Length -gt 0) {
+ $params = Get-Content -Path $args[0] | ConvertFrom-Json
+} else {
+ $params = $complex_args
+}
+if (-not $params.ContainsKey("resource_name")) {
+ $res = @{
+ msg = "missing required argument: resource_name"
+ failed = $true
+ }
+ Write-Output -InputObject (ConvertTo-Json -Compress -InputObject $res)
+ exit 1
+}
+$resource_name = $params.resource_name
+
+if ($params.ContainsKey("module_version")) {
+ $module_version = $params.module_version
+} else {
+ $module_version = "latest"
+}
+
+$module_versions = (Get-DscResource -Name $resource_name -ErrorAction SilentlyContinue | Sort-Object -Property Version)
+$resource = $null
+if ($module_version -eq "latest" -and $null -ne $module_versions) {
+ $resource = $module_versions[-1]
+} elseif ($module_version -ne "latest") {
+ $resource = $module_versions | Where-Object { $_.Version -eq $module_version }
+}
+
+if (-not $resource) {
+ if ($module_version -eq "latest") {
+ $msg = "Resource '$resource_name' not found."
+ } else {
+ $msg = "Resource '$resource_name' with version '$module_version' not found."
+ $msg += " Versions installed: '$($module_versions.Version -join "', '")'."
+ }
+
+ Write-Output -InputObject (ConvertTo-Json -Compress -InputObject @{ failed = $true; msg = $msg })
+ exit 1
+}
+
+# Build the base args for the DSC Invocation based on the resource selected
+$dsc_args = @{
+ Name = $resource.Name
+}
+
+# Binary resources do not work very well with that approach, so we need to guesstimate the module name/version
+$module_version = $null
+if ($resource.Module) {
+ $dsc_args.ModuleName = @{
+ ModuleName = $resource.Module.Name
+ ModuleVersion = $resource.Module.Version
+ }
+ $module_version = $resource.Module.Version.ToString()
+} else {
+ $dsc_args.ModuleName = "PSDesiredStateConfiguration"
+}
+
+# To ensure the class registered with CIM is the one based on our version, we want to run the Get method so the DSC
+# engine updates the metadata property. We don't care about any errors here
+try {
+ Invoke-DscResource -Method Get -Property @{Fake="Fake"} @dsc_args > $null
+} catch {}
+
+# Dynamically build the option spec based on the resource_name specified and create the module object
+$spec = Get-OptionSpec -ClassName $resource.ResourceType
+$spec.supports_check_mode = $true
+$spec.options.module_version = @{ type = "str"; default = "latest" }
+$spec.options.resource_name = @{ type = "str"; required = $true }
+
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+$module.Result.reboot_required = $false
+$module.Result.module_version = $module_version
+
+# Build the DSC invocation arguments and invoke the resource
+$dsc_args.Property = ConvertTo-DscProperty -ClassName $resource.ResourceType -Module $module -Params $module.Params
+$dsc_args.Verbose = $true
+
+$test_result = Invoke-DscMethod -Module $module -Method Test -Arguments $dsc_args
+if ($test_result.InDesiredState -ne $true) {
+ if (-not $module.CheckMode) {
+ $result = Invoke-DscMethod -Module $module -Method Set -Arguments $dsc_args
+ $module.Result.reboot_required = $result.RebootRequired
+ }
+ $module.Result.changed = $true
+}
+
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_dsc.py b/test/support/windows-integration/plugins/modules/win_dsc.py
new file mode 100644
index 00000000..200d025e
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_dsc.py
@@ -0,0 +1,183 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Trond Hindenes <trond@hindenes.com>, and others
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_dsc
+version_added: "2.4"
+short_description: Invokes a PowerShell DSC configuration
+description:
+- Configures a resource using PowerShell DSC.
+- Requires PowerShell version 5.0 or newer.
+- Most of the options for this module are dynamic and will vary depending on
+ the DSC Resource specified in I(resource_name).
+- See :doc:`/user_guide/windows_dsc` for more information on how to use this module.
+options:
+ resource_name:
+ description:
+ - The name of the DSC Resource to use.
+ - Must be accessible to PowerShell using any of the default paths.
+ type: str
+ required: yes
+ module_version:
+ description:
+ - Can be used to configure the exact version of the DSC resource to be
+ invoked.
+    - Useful if the target node has multiple versions of the module containing
+      the DSC resource installed.
+ - If not specified, the module will follow standard PowerShell convention
+ and use the highest version available.
+ type: str
+ default: latest
+ free_form:
+ description:
+    - The M(win_dsc) module takes in multiple free-form options based on the
+      DSC resource being invoked by I(resource_name).
+    - There is no option actually named C(free_form), so see the examples.
+    - This module will try to convert the option to the correct type required
+      by the DSC resource and throw a warning if it fails.
+ - If the type of the DSC resource option is a C(CimInstance) or
+ C(CimInstance[]), this means the value should be a dictionary or list
+ of dictionaries based on the values required by that option.
+ - If the type of the DSC resource option is a C(PSCredential) then there
+ needs to be 2 options set in the Ansible task definition suffixed with
+ C(_username) and C(_password).
+    - If the type of the DSC resource option is an array, then a list should be
+      provided, but a comma-separated string also works. Use a list where
+      possible, as no escaping is required and it works with more complex types
+      like C(CimInstance[]).
+    - If the type of the DSC resource option is a C(DateTime), you should use
+      a string in ISO 8601 format to ensure the exact date is used.
+    - Since Ansible 2.8, Ansible validates the input fields against the
+      DSC resource definition automatically. Older versions will silently
+      ignore invalid fields.
+ type: str
+ required: true
+notes:
+- By default there are a few builtin resources that come with PowerShell 5.0,
+ see U(https://docs.microsoft.com/en-us/powershell/scripting/dsc/resources/resources) for
+ more information on these resources.
+- Custom DSC resources can be installed with M(win_psmodule) using the I(name)
+ option.
+- The DSC engine runs each task as the SYSTEM account; any resources that need
+  to be accessed with a different account need to have C(PsDscRunAsCredential)
+  set.
+- To see the valid options for a DSC resource, run the module with C(-vvv) to
+ show the possible module invocation. Default values are not shown in this
+ output but are applied within the DSC engine.
+author:
+- Trond Hindenes (@trondhindenes)
+'''
+
+EXAMPLES = r'''
+- name: Extract zip file
+ win_dsc:
+ resource_name: Archive
+ Ensure: Present
+ Path: C:\Temp\zipfile.zip
+ Destination: C:\Temp\Temp2
+
+- name: Install a Windows feature with the WindowsFeature resource
+ win_dsc:
+ resource_name: WindowsFeature
+ Name: telnet-client
+
+- name: Edit HKCU reg key under specific user
+ win_dsc:
+ resource_name: Registry
+ Ensure: Present
+ Key: HKEY_CURRENT_USER\ExampleKey
+ ValueName: TestValue
+ ValueData: TestData
+ PsDscRunAsCredential_username: '{{ansible_user}}'
+ PsDscRunAsCredential_password: '{{ansible_password}}'
+ no_log: true
+
+- name: Create file with multiple attributes
+ win_dsc:
+ resource_name: File
+ DestinationPath: C:\ansible\dsc
+ Attributes: # can also be a comma separated string, e.g. 'Hidden, System'
+ - Hidden
+ - System
+ Ensure: Present
+ Type: Directory
+
+- name: Call DSC resource with DateTime option
+ win_dsc:
+ resource_name: DateTimeResource
+ DateTimeOption: '2019-02-22T13:57:31.2311892+00:00'
+
+# more complex example using custom DSC resource and dict values
+- name: Setup the xWebAdministration module
+ win_psmodule:
+ name: xWebAdministration
+ state: present
+
+- name: Create IIS Website with Binding and Authentication options
+ win_dsc:
+ resource_name: xWebsite
+ Ensure: Present
+ Name: DSC Website
+ State: Started
+ PhysicalPath: C:\inetpub\wwwroot
+ BindingInfo: # Example of a CimInstance[] DSC parameter (list of dicts)
+ - Protocol: https
+ Port: 1234
+ CertificateStoreName: MY
+ CertificateThumbprint: C676A89018C4D5902353545343634F35E6B3A659
+ HostName: DSCTest
+ IPAddress: '*'
+ SSLFlags: '1'
+ - Protocol: http
+ Port: 4321
+ IPAddress: '*'
+ AuthenticationInfo: # Example of a CimInstance DSC parameter (dict)
+ Anonymous: no
+ Basic: true
+ Digest: false
+ Windows: yes
+'''
+
+RETURN = r'''
+module_version:
+  description: The version of the DSC resource/module used.
+ returned: always
+ type: str
+ sample: "1.0.1"
+reboot_required:
+ description: Flag returned from the DSC engine indicating whether or not
+ the machine requires a reboot for the invoked changes to take effect.
+ returned: always
+ type: bool
+ sample: true
+verbose_test:
+ description: The verbose output as a list from executing the DSC test
+ method.
+ returned: Ansible verbosity is -vvv or greater
+ type: list
+ sample: [
+ "Perform operation 'Invoke CimMethod' with the following parameters, ",
+ "[SERVER]: LCM: [Start Test ] [[File]DirectResourceAccess]",
+ "Operation 'Invoke CimMethod' complete."
+ ]
+verbose_set:
+ description: The verbose output as a list from executing the DSC Set
+ method.
+ returned: Ansible verbosity is -vvv or greater and a change occurred
+ type: list
+ sample: [
+ "Perform operation 'Invoke CimMethod' with the following parameters, ",
+ "[SERVER]: LCM: [Start Set ] [[File]DirectResourceAccess]",
+ "Operation 'Invoke CimMethod' complete."
+ ]
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_feature.ps1 b/test/support/windows-integration/plugins/modules/win_feature.ps1
new file mode 100644
index 00000000..9a7e1c30
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_feature.ps1
@@ -0,0 +1,111 @@
+#!powershell
+
+# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+
+Import-Module -Name ServerManager
+
+$result = @{
+ changed = $false
+}
+
+$params = Parse-Args $args -supports_check_mode $true
+$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false
+
+$name = Get-AnsibleParam -obj $params -name "name" -type "list" -failifempty $true
+$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent"
+
+$include_sub_features = Get-AnsibleParam -obj $params -name "include_sub_features" -type "bool" -default $false
+$include_management_tools = Get-AnsibleParam -obj $params -name "include_management_tools" -type "bool" -default $false
+$source = Get-AnsibleParam -obj $params -name "source" -type "str"
+
+$install_cmdlet = $false
+if (Get-Command -Name Install-WindowsFeature -ErrorAction SilentlyContinue) {
+ Set-Alias -Name Install-AnsibleWindowsFeature -Value Install-WindowsFeature
+ Set-Alias -Name Uninstall-AnsibleWindowsFeature -Value Uninstall-WindowsFeature
+ $install_cmdlet = $true
+} elseif (Get-Command -Name Add-WindowsFeature -ErrorAction SilentlyContinue) {
+ Set-Alias -Name Install-AnsibleWindowsFeature -Value Add-WindowsFeature
+ Set-Alias -Name Uninstall-AnsibleWindowsFeature -Value Remove-WindowsFeature
+} else {
+ Fail-Json -obj $result -message "This version of Windows does not support the cmdlets Install-WindowsFeature or Add-WindowsFeature"
+}
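+# Install/Uninstall-WindowsFeature (Server 2012+) accept extra parameters such as
+# -IncludeManagementTools and -Source; the older Add/Remove-WindowsFeature cmdlets
+# (Server 2008 R2) do not, which is why $install_cmdlet is only set for the former.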
+
+if ($state -eq "present") {
+ $install_args = @{
+ Name = $name
+ IncludeAllSubFeature = $include_sub_features
+ Restart = $false
+ WhatIf = $check_mode
+ ErrorAction = "Stop"
+ }
+
+ if ($install_cmdlet) {
+ $install_args.IncludeManagementTools = $include_management_tools
+ $install_args.Confirm = $false
+ if ($source) {
+ if (-not (Test-Path -Path $source)) {
+ Fail-Json -obj $result -message "Failed to find source path $source for feature install"
+ }
+ $install_args.Source = $source
+ }
+ }
+
+ try {
+ $action_results = Install-AnsibleWindowsFeature @install_args
+ } catch {
+ Fail-Json -obj $result -message "Failed to install Windows Feature: $($_.Exception.Message)"
+ }
+} else {
+ $uninstall_args = @{
+ Name = $name
+ Restart = $false
+ WhatIf = $check_mode
+ ErrorAction = "Stop"
+ }
+ if ($install_cmdlet) {
+ $uninstall_args.IncludeManagementTools = $include_management_tools
+ }
+
+ try {
+ $action_results = Uninstall-AnsibleWindowsFeature @uninstall_args
+ } catch {
+ Fail-Json -obj $result -message "Failed to uninstall Windows Feature: $($_.Exception.Message)"
+ }
+}
+
+# Loop through results and create a hash containing details about
+# each role/feature that is installed/removed
+# $action_results.FeatureResult is not empty if anything was changed
+$feature_results = @()
+foreach ($action_result in $action_results.FeatureResult) {
+ $message = @()
+ foreach ($msg in $action_result.Message) {
+ $message += @{
+ message_type = $msg.MessageType.ToString()
+ error_code = $msg.ErrorCode
+ text = $msg.Text
+ }
+ }
+
+ $feature_results += @{
+ id = $action_result.Id
+ display_name = $action_result.DisplayName
+ message = $message
+ reboot_required = ConvertTo-Bool -obj $action_result.RestartNeeded
+ skip_reason = $action_result.SkipReason.ToString()
+ success = ConvertTo-Bool -obj $action_result.Success
+ restart_needed = ConvertTo-Bool -obj $action_result.RestartNeeded
+ }
+ $result.changed = $true
+}
+$result.feature_result = $feature_results
+$result.success = ConvertTo-Bool -obj $action_results.Success
+$result.exitcode = $action_results.ExitCode.ToString()
+$result.reboot_required = ConvertTo-Bool -obj $action_results.RestartNeeded
+# controls whether Ansible will fail or not
+$result.failed = (-not $action_results.Success)
+
+Exit-Json -obj $result
diff --git a/test/support/windows-integration/plugins/modules/win_feature.py b/test/support/windows-integration/plugins/modules/win_feature.py
new file mode 100644
index 00000000..62e310b2
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_feature.py
@@ -0,0 +1,149 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com>
+# Copyright: (c) 2014, Trond Hindenes <trond@hindenes.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_feature
+version_added: "1.7"
+short_description: Installs and uninstalls Windows Features on Windows Server
+description:
+ - Installs or uninstalls Windows Roles or Features on Windows Server.
+ - This module uses the Add/Remove-WindowsFeature Cmdlets on Windows 2008 R2
+    and Install/Uninstall-WindowsFeature Cmdlets on Windows 2012, which are not available on client OS machines.
+options:
+ name:
+ description:
+ - Names of roles or features to install as a single feature or a comma-separated list of features.
+ - To list all available features use the PowerShell command C(Get-WindowsFeature).
+ type: list
+ required: yes
+ state:
+ description:
+ - State of the features or roles on the system.
+ type: str
+ choices: [ absent, present ]
+ default: present
+ include_sub_features:
+ description:
+ - Adds all subfeatures of the specified feature.
+ type: bool
+ default: no
+ include_management_tools:
+ description:
+ - Adds the corresponding management tools to the specified feature.
+ - Not supported in Windows 2008 R2 and will be ignored.
+ type: bool
+ default: no
+ source:
+ description:
+ - Specify a source to install the feature from.
+ - Not supported in Windows 2008 R2 and will be ignored.
+ - Can either be C({driveletter}:\sources\sxs) or C(\\{IP}\share\sources\sxs).
+ type: str
+ version_added: "2.1"
+seealso:
+- module: win_chocolatey
+- module: win_package
+author:
+ - Paul Durivage (@angstwad)
+ - Trond Hindenes (@trondhindenes)
+'''
+
+EXAMPLES = r'''
+- name: Install IIS (Web-Server only)
+ win_feature:
+ name: Web-Server
+ state: present
+
+- name: Install IIS (Web-Server and Web-Common-Http)
+ win_feature:
+ name:
+ - Web-Server
+ - Web-Common-Http
+ state: present
+
+- name: Install NET-Framework-Core from file
+ win_feature:
+ name: NET-Framework-Core
+ source: C:\Temp\iso\sources\sxs
+ state: present
+
+- name: Install IIS Web-Server with sub features and management tools
+ win_feature:
+ name: Web-Server
+ state: present
+ include_sub_features: yes
+ include_management_tools: yes
+ register: win_feature
+
+- name: Reboot if installing Web-Server feature requires it
+ win_reboot:
+ when: win_feature.reboot_required
+'''
+
+RETURN = r'''
+exitcode:
+ description: The stringified exit code from the feature installation/removal command.
+ returned: always
+ type: str
+ sample: Success
+feature_result:
+ description: List of features that were installed or removed.
+ returned: success
+ type: complex
+ sample:
+ contains:
+ display_name:
+ description: Feature display name.
+ returned: always
+ type: str
+ sample: "Telnet Client"
+ id:
+      description: The numeric identifier of the feature.
+ returned: always
+ type: int
+ sample: 44
+ message:
+ description: Any messages returned from the feature subsystem that occurred during installation or removal of this feature.
+ returned: always
+ type: list
+ elements: str
+ sample: []
+ reboot_required:
+ description: True when the target server requires a reboot as a result of installing or removing this feature.
+ returned: always
+ type: bool
+ sample: true
+ restart_needed:
+ description: DEPRECATED in Ansible 2.4 (refer to C(reboot_required) instead). True when the target server requires a reboot as a
+ result of installing or removing this feature.
+ returned: always
+ type: bool
+ sample: true
+ skip_reason:
+ description: The reason a feature installation or removal was skipped.
+ returned: always
+ type: str
+ sample: NotSkipped
+ success:
+ description: If the feature installation or removal was successful.
+ returned: always
+ type: bool
+ sample: true
+reboot_required:
+ description: True when the target server requires a reboot to complete updates (no further updates can be installed until after a reboot).
+ returned: success
+ type: bool
+ sample: true
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_file.ps1 b/test/support/windows-integration/plugins/modules/win_file.ps1
new file mode 100644
index 00000000..54427549
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_file.ps1
@@ -0,0 +1,152 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+
+$ErrorActionPreference = "Stop"
+
+$params = Parse-Args $args -supports_check_mode $true
+
+$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -default $false
+$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP
+
+$path = Get-AnsibleParam -obj $params -name "path" -type "path" -failifempty $true -aliases "dest","name"
+$state = Get-AnsibleParam -obj $params -name "state" -type "str" -validateset "absent","directory","file","touch"
+
+# used in template/copy when dest is the path to a dir and source is a file
+$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str"
+if ((Test-Path -LiteralPath $path -PathType Container) -and ($null -ne $original_basename)) {
+ $path = Join-Path -Path $path -ChildPath $original_basename
+}
+
+$result = @{
+ changed = $false
+}
+
+# Used to delete symlinks as PowerShell cannot delete broken symlinks
+$symlink_util = @"
+using System;
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+
+namespace Ansible.Command {
+ public class SymLinkHelper {
+ [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)]
+ public static extern bool DeleteFileW(string lpFileName);
+
+ [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)]
+ public static extern bool RemoveDirectoryW(string lpPathName);
+
+ public static void DeleteDirectory(string path) {
+ if (!RemoveDirectoryW(path))
+ throw new Exception(String.Format("RemoveDirectoryW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message));
+ }
+
+ public static void DeleteFile(string path) {
+ if (!DeleteFileW(path))
+ throw new Exception(String.Format("DeleteFileW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message));
+ }
+ }
+}
+"@
+$original_tmp = $env:TMP
+$env:TMP = $_remote_tmp
+Add-Type -TypeDefinition $symlink_util
+$env:TMP = $original_tmp
+
+# Used to delete directories and files with logic on handling symbolic links
+function Remove-File($file, $checkmode) {
+ try {
+ if ($file.Attributes -band [System.IO.FileAttributes]::ReparsePoint) {
+            # Bug with PowerShell: deleting a symbolic link that points to an invalid
+            # path fails, so we use the Win32 API to do this instead
+ if ($file.PSIsContainer) {
+ if (-not $checkmode) {
+ [Ansible.Command.SymLinkHelper]::DeleteDirectory($file.FullName)
+ }
+ } else {
+ if (-not $checkmode) {
+                    [Ansible.Command.SymLinkHelper]::DeleteFile($file.FullName)
+ }
+ }
+ } elseif ($file.PSIsContainer) {
+ Remove-Directory -directory $file -checkmode $checkmode
+ } else {
+ Remove-Item -LiteralPath $file.FullName -Force -WhatIf:$checkmode
+ }
+ } catch [Exception] {
+ Fail-Json $result "Failed to delete $($file.FullName): $($_.Exception.Message)"
+ }
+}
+
+function Remove-Directory($directory, $checkmode) {
+ foreach ($file in Get-ChildItem -LiteralPath $directory.FullName) {
+ Remove-File -file $file -checkmode $checkmode
+ }
+ Remove-Item -LiteralPath $directory.FullName -Force -Recurse -WhatIf:$checkmode
+}
+
+
+if ($state -eq "touch") {
+ if (Test-Path -LiteralPath $path) {
+ if (-not $check_mode) {
+ (Get-ChildItem -LiteralPath $path).LastWriteTime = Get-Date
+ }
+ $result.changed = $true
+ } else {
+ Write-Output $null | Out-File -LiteralPath $path -Encoding ASCII -WhatIf:$check_mode
+ $result.changed = $true
+ }
+}
+
+if (Test-Path -LiteralPath $path) {
+ $fileinfo = Get-Item -LiteralPath $path -Force
+ if ($state -eq "absent") {
+ Remove-File -file $fileinfo -checkmode $check_mode
+ $result.changed = $true
+ } else {
+ if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) {
+ Fail-Json $result "path $path is not a directory"
+ }
+
+ if ($state -eq "file" -and $fileinfo.PsIsContainer) {
+ Fail-Json $result "path $path is not a file"
+ }
+ }
+
+} else {
+
+ # If state is not supplied, test the $path to see if it looks like
+ # a file or a folder and set state to file or folder
+ if ($null -eq $state) {
+ $basename = Split-Path -Path $path -Leaf
+ if ($basename.length -gt 0) {
+ $state = "file"
+ } else {
+ $state = "directory"
+ }
+ }
+
+ if ($state -eq "directory") {
+ try {
+ New-Item -Path $path -ItemType Directory -WhatIf:$check_mode | Out-Null
+ } catch {
+ if ($_.CategoryInfo.Category -eq "ResourceExists") {
+ $fileinfo = Get-Item -LiteralPath $_.CategoryInfo.TargetName
+ if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) {
+ Fail-Json $result "path $path is not a directory"
+ }
+ } else {
+ Fail-Json $result $_.Exception.Message
+ }
+ }
+ $result.changed = $true
+ } elseif ($state -eq "file") {
+ Fail-Json $result "path $path will not be created"
+ }
+
+}
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_file.py b/test/support/windows-integration/plugins/modules/win_file.py
new file mode 100644
index 00000000..28149579
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_file.py
@@ -0,0 +1,70 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_file
+version_added: "1.9.2"
+short_description: Creates, touches or removes files or directories
+description:
+ - Creates (empty) files, updates file modification stamps of existing files,
+ and can create or remove directories.
+ - Unlike M(file), does not modify ownership, permissions or manipulate links.
+ - For non-Windows targets, use the M(file) module instead.
+options:
+ path:
+ description:
+ - Path to the file being managed.
+ required: yes
+ type: path
+ aliases: [ dest, name ]
+ state:
+ description:
+ - If C(directory), all immediate subdirectories will be created if they
+ do not exist.
+ - If C(file), the file will NOT be created if it does not exist, see the M(copy)
+ or M(template) module if you want that behavior.
+ - If C(absent), directories will be recursively deleted, and files will be removed.
+ - If C(touch), an empty file will be created if the C(path) does not
+ exist, while an existing file or directory will receive updated file access and
+ modification times (similar to the way C(touch) works from the command line).
+ type: str
+ choices: [ absent, directory, file, touch ]
+seealso:
+- module: file
+- module: win_acl
+- module: win_acl_inheritance
+- module: win_owner
+- module: win_stat
+author:
+- Jon Hawkesworth (@jhawkesworth)
+'''
+
+EXAMPLES = r'''
+- name: Touch a file (creates if not present, updates modification time if present)
+ win_file:
+ path: C:\Temp\foo.conf
+ state: touch
+
+- name: Remove a file, if present
+ win_file:
+ path: C:\Temp\foo.conf
+ state: absent
+
+- name: Create directory structure
+ win_file:
+ path: C:\Temp\folder\subfolder
+ state: directory
+
+- name: Remove directory structure
+ win_file:
+ path: C:\Temp
+ state: absent
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_find.ps1 b/test/support/windows-integration/plugins/modules/win_find.ps1
new file mode 100644
index 00000000..bc57c5ff
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_find.ps1
@@ -0,0 +1,416 @@
+#!powershell
+
+# Copyright: (c) 2016, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+#Requires -Module Ansible.ModuleUtils.LinkUtil
+
+$spec = @{
+ options = @{
+ paths = @{ type = "list"; elements = "str"; required = $true }
+ age = @{ type = "str" }
+ age_stamp = @{ type = "str"; default = "mtime"; choices = "mtime", "ctime", "atime" }
+ file_type = @{ type = "str"; default = "file"; choices = "file", "directory" }
+ follow = @{ type = "bool"; default = $false }
+ hidden = @{ type = "bool"; default = $false }
+ patterns = @{ type = "list"; elements = "str"; aliases = "regex", "regexp" }
+ recurse = @{ type = "bool"; default = $false }
+ size = @{ type = "str" }
+ use_regex = @{ type = "bool"; default = $false }
+ get_checksum = @{ type = "bool"; default = $true }
+ checksum_algorithm = @{ type = "str"; default = "sha1"; choices = "md5", "sha1", "sha256", "sha384", "sha512" }
+ }
+ supports_check_mode = $true
+}
+
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+
+$paths = $module.Params.paths
+$age = $module.Params.age
+$age_stamp = $module.Params.age_stamp
+$file_type = $module.Params.file_type
+$follow = $module.Params.follow
+$hidden = $module.Params.hidden
+$patterns = $module.Params.patterns
+$recurse = $module.Params.recurse
+$size = $module.Params.size
+$use_regex = $module.Params.use_regex
+$get_checksum = $module.Params.get_checksum
+$checksum_algorithm = $module.Params.checksum_algorithm
+
+$module.Result.examined = 0
+$module.Result.files = @()
+$module.Result.matched = 0
+
+Load-LinkUtils
+
+Function Assert-Age {
+ Param (
+ [System.IO.FileSystemInfo]$File,
+ [System.Int64]$Age,
+ [System.String]$AgeStamp
+ )
+
+ $actual_age = switch ($AgeStamp) {
+ mtime { $File.LastWriteTime.Ticks }
+ ctime { $File.CreationTime.Ticks }
+ atime { $File.LastAccessTime.Ticks }
+ }
+
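+    # The caller encodes Age in ticks (see the age parsing at the bottom of this
+    # script): a positive value is a cutoff tick count and matches files whose
+    # timestamp is at or before it ("at least this old"); a negative value is
+    # the negated cutoff and matches files at or after it instead.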
+ if ($Age -ge 0) {
+ return $Age -ge $actual_age
+ } else {
+ return ($Age * -1) -le $actual_age
+ }
+}
+
+Function Assert-FileType {
+ Param (
+ [System.IO.FileSystemInfo]$File,
+ [System.String]$FileType
+ )
+
+ $is_dir = $File.Attributes.HasFlag([System.IO.FileAttributes]::Directory)
+ return ($FileType -eq 'directory' -and $is_dir) -or ($FileType -eq 'file' -and -not $is_dir)
+}
+
+Function Assert-FileHidden {
+ Param (
+ [System.IO.FileSystemInfo]$File,
+ [Switch]$IsHidden
+ )
+
+ $file_is_hidden = $File.Attributes.HasFlag([System.IO.FileAttributes]::Hidden)
+ return $IsHidden.IsPresent -eq $file_is_hidden
+}
+
+
+Function Assert-FileNamePattern {
+ Param (
+ [System.IO.FileSystemInfo]$File,
+ [System.String[]]$Patterns,
+ [Switch]$UseRegex
+ )
+
+ $valid_match = $false
+ foreach ($pattern in $Patterns) {
+ if ($UseRegex) {
+ if ($File.Name -match $pattern) {
+ $valid_match = $true
+ break
+ }
+ } else {
+ if ($File.Name -like $pattern) {
+ $valid_match = $true
+ break
+ }
+ }
+ }
+ return $valid_match
+}
+
+Function Assert-FileSize {
+ Param (
+ [System.IO.FileSystemInfo]$File,
+ [System.Int64]$Size
+ )
+
+ if ($Size -ge 0) {
+ return $File.Length -ge $Size
+ } else {
+ return $File.Length -le ($Size * -1)
+ }
+}
+
+Function Get-FileChecksum {
+ Param (
+ [System.String]$Path,
+ [System.String]$Algorithm
+ )
+
+ $sp = switch ($algorithm) {
+ 'md5' { New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider }
+ 'sha1' { New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider }
+ 'sha256' { New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider }
+ 'sha384' { New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider }
+ 'sha512' { New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider }
+ }
+
+ $fp = [System.IO.File]::Open($Path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, [System.IO.FileShare]::ReadWrite)
+ try {
+ $hash = [System.BitConverter]::ToString($sp.ComputeHash($fp)).Replace("-", "").ToLower()
+ } finally {
+ $fp.Dispose()
+ }
+
+ return $hash
+}
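+
+# Usage sketch: Get-FileChecksum -Path 'C:\Temp\foo.txt' -Algorithm sha256
+# returns the lower-case hex digest. The file is opened with FileShare
+# ReadWrite so files currently held open by other processes can still be hashed.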
+
+Function Search-Path {
+ [CmdletBinding()]
+ Param (
+ [Parameter(Mandatory=$true)]
+ [System.String]
+ $Path,
+
+ [Parameter(Mandatory=$true)]
+ [AllowEmptyCollection()]
+ [System.Collections.Generic.HashSet`1[System.String]]
+ $CheckedPaths,
+
+ [Parameter(Mandatory=$true)]
+ [Object]
+ $Module,
+
+ [System.Int64]
+ $Age,
+
+ [System.String]
+ $AgeStamp,
+
+ [System.String]
+ $FileType,
+
+ [Switch]
+ $Follow,
+
+ [Switch]
+ $GetChecksum,
+
+ [Switch]
+ $IsHidden,
+
+ [System.String[]]
+ $Patterns,
+
+ [Switch]
+ $Recurse,
+
+ [System.Int64]
+ $Size,
+
+ [Switch]
+ $UseRegex
+ )
+
+ $dir_obj = New-Object -TypeName System.IO.DirectoryInfo -ArgumentList $Path
+ if ([Int32]$dir_obj.Attributes -eq -1) {
+ $Module.Warn("Argument path '$Path' does not exist, skipping")
+ return
+ } elseif (-not $dir_obj.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) {
+ $Module.Warn("Argument path '$Path' is a file not a directory, skipping")
+ return
+ }
+
+ $dir_files = @()
+ try {
+ $dir_files = $dir_obj.EnumerateFileSystemInfos("*", [System.IO.SearchOption]::TopDirectoryOnly)
+ } catch [System.IO.DirectoryNotFoundException] { # Broken ReparsePoint/Symlink, cannot enumerate
+ } catch [System.UnauthorizedAccessException] {} # No ListDirectory permissions, Get-ChildItem ignored this
+
+ foreach ($dir_child in $dir_files) {
+ if ($dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Directory) -and $Recurse) {
+ if ($Follow -or -not $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::ReparsePoint)) {
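+                # Recurse by splatting the current bound parameters; Path is
+                # removed first so the child directory path below replaces it.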
+ $PSBoundParameters.Remove('Path') > $null
+ Search-Path -Path $dir_child.FullName @PSBoundParameters
+ }
+ }
+
+ # Check to see if we've already encountered this path and skip if we have.
+ if (-not $CheckedPaths.Add($dir_child.FullName.ToLowerInvariant())) {
+ continue
+ }
+
+ $Module.Result.examined++
+
+ if ($PSBoundParameters.ContainsKey('Age')) {
+ $age_match = Assert-Age -File $dir_child -Age $Age -AgeStamp $AgeStamp
+ } else {
+ $age_match = $true
+ }
+
+ $file_type_match = Assert-FileType -File $dir_child -FileType $FileType
+ $hidden_match = Assert-FileHidden -File $dir_child -IsHidden:$IsHidden
+
+ if ($PSBoundParameters.ContainsKey('Patterns')) {
+ $pattern_match = Assert-FileNamePattern -File $dir_child -Patterns $Patterns -UseRegex:$UseRegex.IsPresent
+ } else {
+ $pattern_match = $true
+ }
+
+ if ($PSBoundParameters.ContainsKey('Size')) {
+ $size_match = Assert-FileSize -File $dir_child -Size $Size
+ } else {
+ $size_match = $true
+ }
+
+ if (-not ($age_match -and $file_type_match -and $hidden_match -and $pattern_match -and $size_match)) {
+ continue
+ }
+
+ # It passed all our filters so add it
+ $module.Result.matched++
+
+ # TODO: Make this generic so it can be shared with win_find and win_stat.
+ $epoch = New-Object -Type System.DateTime -ArgumentList 1970, 1, 1, 0, 0, 0, 0
+ $file_info = @{
+ attributes = $dir_child.Attributes.ToString()
+ checksum = $null
+ creationtime = (New-TimeSpan -Start $epoch -End $dir_child.CreationTime).TotalSeconds
+ exists = $true
+ extension = $null
+ filename = $dir_child.Name
+ isarchive = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Archive)
+ isdir = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Directory)
+ ishidden = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Hidden)
+ isreadonly = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::ReadOnly)
+ isreg = $false
+ isshared = $false
+ lastaccesstime = (New-TimeSpan -Start $epoch -End $dir_child.LastAccessTime).TotalSeconds
+ lastwritetime = (New-TimeSpan -Start $epoch -End $dir_child.LastWriteTime).TotalSeconds
+ owner = $null
+ path = $dir_child.FullName
+ sharename = $null
+ size = $null
+ }
+
+ try {
+ $file_info.owner = $dir_child.GetAccessControl().Owner
+ } catch {} # May not have rights to get the Owner, historical behaviour is to ignore.
+
+ if ($dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) {
+ $share_info = Get-CimInstance -ClassName Win32_Share -Filter "Path='$($dir_child.FullName -replace '\\', '\\')'"
+ if ($null -ne $share_info) {
+ $file_info.isshared = $true
+ $file_info.sharename = $share_info.Name
+ }
+ } else {
+ $file_info.extension = $dir_child.Extension
+ $file_info.isreg = $true
+ $file_info.size = $dir_child.Length
+
+ if ($GetChecksum) {
+ try {
+ $file_info.checksum = Get-FileChecksum -Path $dir_child.FullName -Algorithm $checksum_algorithm
+ } catch {} # Just keep the checksum as $null in the case of a failure.
+ }
+ }
+
+ # Append the link information if the path is a link
+ $link_info = @{
+ isjunction = $false
+ islnk = $false
+ nlink = 1
+ lnk_source = $null
+ lnk_target = $null
+ hlnk_targets = @()
+ }
+ $link_stat = Get-Link -link_path $dir_child.FullName
+ if ($null -ne $link_stat) {
+ switch ($link_stat.Type) {
+ "SymbolicLink" {
+ $link_info.islnk = $true
+ $link_info.isreg = $false
+ $link_info.lnk_source = $link_stat.AbsolutePath
+ $link_info.lnk_target = $link_stat.TargetPath
+ break
+ }
+ "JunctionPoint" {
+ $link_info.isjunction = $true
+ $link_info.isreg = $false
+ $link_info.lnk_source = $link_stat.AbsolutePath
+ $link_info.lnk_target = $link_stat.TargetPath
+ break
+ }
+ "HardLink" {
+ $link_info.nlink = $link_stat.HardTargets.Count
+
+ # remove current path from the targets
+                $hlnk_targets = $link_stat.HardTargets | Where-Object { $_ -ne $dir_child.FullName }
+ $link_info.hlnk_targets = @($hlnk_targets)
+ break
+ }
+ }
+ }
+ foreach ($kv in $link_info.GetEnumerator()) {
+ $file_info.$($kv.Key) = $kv.Value
+ }
+
+ # Output the file_info object
+ $file_info
+ }
+}
+
+$search_params = @{
+ CheckedPaths = [System.Collections.Generic.HashSet`1[System.String]]@()
+ GetChecksum = $get_checksum
+ Module = $module
+ FileType = $file_type
+ Follow = $follow
+ IsHidden = $hidden
+ Recurse = $recurse
+}
+
+if ($null -ne $age) {
+ $seconds_per_unit = @{'s'=1; 'm'=60; 'h'=3600; 'd'=86400; 'w'=604800}
+ $seconds_pattern = '^(-?\d+)(s|m|h|d|w)?$'
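+    # e.g. age '-1h' parses to -3600 seconds and matches files changed within
+    # the last hour, while '10d' matches files at least ten days old.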
+ $match = $age -match $seconds_pattern
+ if ($Match) {
+ $specified_seconds = [Int64]$Matches[1]
+ if ($null -eq $Matches[2]) {
+ $chosen_unit = 's'
+ } else {
+ $chosen_unit = $Matches[2]
+ }
+
+ $total_seconds = $specified_seconds * ($seconds_per_unit.$chosen_unit)
+
+ if ($total_seconds -ge 0) {
+ $search_params.Age = (Get-Date).AddSeconds($total_seconds * -1).Ticks
+ } else {
+            # For a negative age the cutoff is now minus the age; store it negated so
+            # Assert-Age knows to match files at or after the cutoff (i.e. newer files).
+ $age = (Get-Date).AddSeconds($total_seconds).Ticks
+ $search_params.Age = $age * -1
+ }
+ $search_params.AgeStamp = $age_stamp
+ } else {
+ $module.FailJson("Invalid age pattern specified")
+ }
+}
+
+if ($null -ne $patterns) {
+ $search_params.Patterns = $patterns
+ $search_params.UseRegex = $use_regex
+}
+
+if ($null -ne $size) {
+ $bytes_per_unit = @{'b'=1; 'k'=1KB; 'm'=1MB; 'g'=1GB;'t'=1TB}
+ $size_pattern = '^(-?\d+)(b|k|m|g|t)?$'
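+    # e.g. size '-1g' resolves to -1073741824 and matches files no larger than
+    # 1GB, while '50k' matches files of at least 50KB.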
+ $match = $size -match $size_pattern
+ if ($Match) {
+ $specified_size = [Int64]$Matches[1]
+ if ($null -eq $Matches[2]) {
+ $chosen_byte = 'b'
+ } else {
+ $chosen_byte = $Matches[2]
+ }
+
+ $search_params.Size = $specified_size * ($bytes_per_unit.$chosen_byte)
+ } else {
+ $module.FailJson("Invalid size pattern specified")
+ }
+}
+
+$matched_files = foreach ($path in $paths) {
+ # Ensure we pass in an absolute path. We use the ExecutionContext as this is based on the PSProvider path not the
+ # process location which can be different.
+ $abs_path = $ExecutionContext.SessionState.Path.GetUnresolvedProviderPathFromPSPath($path)
+ Search-Path -Path $abs_path @search_params
+}
+
+# Make sure we sort the files in alphabetical order.
+$module.Result.files = @() + ($matched_files | Sort-Object -Property {$_.path})
+
+$module.ExitJson()
+
diff --git a/test/support/windows-integration/plugins/modules/win_find.py b/test/support/windows-integration/plugins/modules/win_find.py
new file mode 100644
index 00000000..f506f956
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_find.py
@@ -0,0 +1,345 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2016, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_find
+version_added: "2.3"
+short_description: Return a list of files based on specific criteria
+description:
+ - Return a list of files based on specified criteria.
+ - Multiple criteria are AND'd together.
+ - For non-Windows targets, use the M(find) module instead.
+options:
+ age:
+ description:
+ - Select files or folders whose age is equal to or greater than
+ the specified time.
+ - Use a negative age to find files equal to or less than
+ the specified time.
+    - You can choose seconds, minutes, hours, days or weeks by specifying the
+      first letter of any of those words (e.g., "2s", "10d", "1w").
+ type: str
+ age_stamp:
+ description:
+ - Choose the file property against which we compare C(age).
+ - The default attribute we compare with is the last modification time.
+ type: str
+ choices: [ atime, ctime, mtime ]
+ default: mtime
+ checksum_algorithm:
+ description:
+ - Algorithm to determine the checksum of a file.
+    - Will throw an error if the host is unable to use the specified algorithm.
+ type: str
+ choices: [ md5, sha1, sha256, sha384, sha512 ]
+ default: sha1
+ file_type:
+ description: Type of file to search for.
+ type: str
+ choices: [ directory, file ]
+ default: file
+ follow:
+ description:
+ - Set this to C(yes) to follow symlinks in the path.
+ - This needs to be used in conjunction with C(recurse).
+ type: bool
+ default: no
+ get_checksum:
+ description:
+    - Whether to return a checksum of the file in the return info (default sha1);
+      use C(checksum_algorithm) to change from the default.
+ type: bool
+ default: yes
+ hidden:
+    description: Set this to C(yes) to include hidden files or folders.
+ type: bool
+ default: no
+ paths:
+ description:
+ - List of paths of directories to search for files or folders in.
+ - This can be supplied as a single path or a list of paths.
+ type: list
+ required: yes
+ patterns:
+ description:
+    - One or more (PowerShell or regex) patterns to compare filenames with.
+    - The type of pattern matching is controlled by the C(use_regex) option.
+    - The patterns restrict the list of files or folders to be returned based on the filenames.
+    - For a file to be matched it only has to match one pattern in the list provided.
+ type: list
+ aliases: [ "regex", "regexp" ]
+ recurse:
+ description:
+ - Will recursively descend into the directory looking for files or folders.
+ type: bool
+ default: no
+ size:
+ description:
+ - Select files or folders whose size is equal to or greater than the specified size.
+ - Use a negative value to find files equal to or less than the specified size.
+    - You can specify the size with a suffix of the byte type (e.g. kilo = k, mega = m, giga = g).
+ - Size is not evaluated for symbolic links.
+ type: str
+ use_regex:
+ description:
+ - Will set patterns to run as a regex check if set to C(yes).
+ type: bool
+ default: no
+author:
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+- name: Find files in path
+ win_find:
+ paths: D:\Temp
+
+- name: Find hidden files in path
+ win_find:
+ paths: D:\Temp
+ hidden: yes
+
+- name: Find files in multiple paths
+ win_find:
+ paths:
+ - C:\Temp
+ - D:\Temp
+
+- name: Find files in directory while searching recursively
+ win_find:
+ paths: D:\Temp
+ recurse: yes
+
+- name: Find files in directory while following symlinks
+ win_find:
+ paths: D:\Temp
+ recurse: yes
+ follow: yes
+
+- name: Find files with .log and .out extension using powershell wildcards
+ win_find:
+ paths: D:\Temp
+ patterns: [ '*.log', '*.out' ]
+
+- name: Find files in path based on regex pattern
+ win_find:
+ paths: D:\Temp
+ patterns: out_\d{8}-\d{6}.log
+
+- name: Find files older than 1 day
+ win_find:
+ paths: D:\Temp
+ age: 86400
+
+- name: Find files older than 1 day based on create time
+ win_find:
+ paths: D:\Temp
+ age: 86400
+ age_stamp: ctime
+
+- name: Find files older than 1 day with unit syntax
+ win_find:
+ paths: D:\Temp
+ age: 1d
+
+- name: Find files newer than 1 hour
+ win_find:
+ paths: D:\Temp
+ age: -3600
+
+- name: Find files newer than 1 hour with unit syntax
+ win_find:
+ paths: D:\Temp
+ age: -1h
+
+- name: Find files larger than 1MB
+ win_find:
+ paths: D:\Temp
+ size: 1048576
+
+- name: Find files larger than 1GB with unit syntax
+ win_find:
+ paths: D:\Temp
+ size: 1g
+
+- name: Find files smaller than 1MB
+ win_find:
+ paths: D:\Temp
+ size: -1048576
+
+- name: Find files smaller than 1GB with unit syntax
+ win_find:
+ paths: D:\Temp
+ size: -1g
+
+- name: Find folders/symlinks in multiple paths
+ win_find:
+ paths:
+ - C:\Temp
+ - D:\Temp
+ file_type: directory
+
+- name: Find files and return SHA256 checksum of files found
+ win_find:
+ paths: C:\Temp
+ get_checksum: yes
+ checksum_algorithm: sha256
+
+- name: Find files and do not return the checksum
+ win_find:
+ paths: C:\Temp
+ get_checksum: no
+'''
+
+RETURN = r'''
+examined:
+  description: The number of files/folders that were checked.
+ returned: always
+ type: int
+ sample: 10
+matched:
+ description: The number of files/folders that match the criteria.
+ returned: always
+ type: int
+ sample: 2
+files:
+  description: Information on the files/folders that match the criteria, returned as a list of dictionary elements
+    for each file matched. The entries are sorted by the path value alphabetically.
+ returned: success
+ type: complex
+ contains:
+ attributes:
+ description: attributes of the file at path in raw form.
+ returned: success, path exists
+ type: str
+ sample: "Archive, Hidden"
+ checksum:
+ description: The checksum of a file based on checksum_algorithm specified.
+ returned: success, path exists, path is a file, get_checksum == True
+ type: str
+ sample: 09cb79e8fc7453c84a07f644e441fd81623b7f98
+ creationtime:
+ description: The create time of the file represented in seconds since epoch.
+ returned: success, path exists
+ type: float
+ sample: 1477984205.15
+ exists:
+      description: Whether the file exists; this will always be true for M(win_find).
+ returned: success, path exists
+ type: bool
+ sample: true
+ extension:
+ description: The extension of the file at path.
+ returned: success, path exists, path is a file
+ type: str
+ sample: ".ps1"
+ filename:
+ description: The name of the file.
+ returned: success, path exists
+ type: str
+ sample: temp
+ hlnk_targets:
+ description: List of other files pointing to the same file (hard links), excludes the current file.
+ returned: success, path exists
+ type: list
+ sample:
+ - C:\temp\file.txt
+ - C:\Windows\update.log
+ isarchive:
+ description: If the path is ready for archiving or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isdir:
+ description: If the path is a directory or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ ishidden:
+ description: If the path is hidden or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isjunction:
+ description: If the path is a junction point.
+ returned: success, path exists
+ type: bool
+ sample: true
+ islnk:
+ description: If the path is a symbolic link.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isreadonly:
+ description: If the path is read only or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isreg:
+ description: If the path is a regular file or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isshared:
+ description: If the path is shared or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ lastaccesstime:
+ description: The last access time of the file represented in seconds since epoch.
+ returned: success, path exists
+ type: float
+ sample: 1477984205.15
+ lastwritetime:
+ description: The last modification time of the file represented in seconds since epoch.
+ returned: success, path exists
+ type: float
+ sample: 1477984205.15
+ lnk_source:
+ description: The target of the symlink normalized for the remote filesystem.
+ returned: success, path exists, path is a symbolic link or junction point
+ type: str
+ sample: C:\temp
+ lnk_target:
+ description: The target of the symlink. Note that relative paths remain relative, will return null if not a link.
+ returned: success, path exists, path is a symbolic link or junction point
+ type: str
+ sample: temp
+ nlink:
+      description: Number of links to the file (hard links).
+ returned: success, path exists
+ type: int
+ sample: 1
+ owner:
+ description: The owner of the file.
+ returned: success, path exists
+ type: str
+ sample: BUILTIN\Administrators
+ path:
+ description: The full absolute path to the file.
+ returned: success, path exists
+ type: str
+      sample: C:\Temp\folder\file.txt
+ sharename:
+ description: The name of share if folder is shared.
+ returned: success, path exists, path is a directory and isshared == True
+ type: str
+ sample: file-share
+ size:
+ description: The size in bytes of the file.
+ returned: success, path exists, path is a file
+ type: int
+ sample: 1024
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_format.ps1 b/test/support/windows-integration/plugins/modules/win_format.ps1
new file mode 100644
index 00000000..b5fd3ae0
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_format.ps1
@@ -0,0 +1,200 @@
+#!powershell
+
+# Copyright: (c) 2019, Varun Chopra (@chopraaa) <v@chopraaa.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+#AnsibleRequires -OSVersion 6.2
+
+Set-StrictMode -Version 2
+
+$ErrorActionPreference = "Stop"
+
+$spec = @{
+ options = @{
+ drive_letter = @{ type = "str" }
+ path = @{ type = "str" }
+ label = @{ type = "str" }
+ new_label = @{ type = "str" }
+ file_system = @{ type = "str"; choices = "ntfs", "refs", "exfat", "fat32", "fat" }
+ allocation_unit_size = @{ type = "int" }
+ large_frs = @{ type = "bool" }
+ full = @{ type = "bool"; default = $false }
+ compress = @{ type = "bool" }
+ integrity_streams = @{ type = "bool" }
+ force = @{ type = "bool"; default = $false }
+ }
+ mutually_exclusive = @(
+ ,@('drive_letter', 'path', 'label')
+ )
+ required_one_of = @(
+ ,@('drive_letter', 'path', 'label')
+ )
+ supports_check_mode = $true
+}
+
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+
+$drive_letter = $module.Params.drive_letter
+$path = $module.Params.path
+$label = $module.Params.label
+$new_label = $module.Params.new_label
+$file_system = $module.Params.file_system
+$allocation_unit_size = $module.Params.allocation_unit_size
+$large_frs = $module.Params.large_frs
+$full_format = $module.Params.full
+$compress_volume = $module.Params.compress
+$integrity_streams = $module.Params.integrity_streams
+$force_format = $module.Params.force
+
+# Some pre-checks
+if ($null -ne $drive_letter -and $drive_letter -notmatch "^[a-zA-Z]$") {
+ $module.FailJson("The parameter drive_letter should be a single character A-Z")
+}
+if ($integrity_streams -eq $true -and $file_system -ne "refs") {
+ $module.FailJson("Integrity streams can be enabled only on ReFS volumes. You specified: $($file_system)")
+}
+if ($compress_volume -eq $true) {
+ if ($file_system -eq "ntfs") {
+ if ($null -ne $allocation_unit_size -and $allocation_unit_size -gt 4096) {
+ $module.FailJson("NTFS compression is not supported for allocation unit sizes above 4096")
+ }
+ }
+ else {
+ $module.FailJson("Compression can be enabled only on NTFS volumes. You specified: $($file_system)")
+ }
+}
+
+function Get-AnsibleVolume {
+ param(
+ $DriveLetter,
+ $Path,
+ $Label
+ )
+
+ if ($null -ne $DriveLetter) {
+ try {
+ $volume = Get-Volume -DriveLetter $DriveLetter
+ } catch {
+ $module.FailJson("There was an error retrieving the volume using drive_letter $($DriveLetter): $($_.Exception.Message)", $_)
+ }
+ }
+ elseif ($null -ne $Path) {
+ try {
+ $volume = Get-Volume -Path $Path
+ } catch {
+ $module.FailJson("There was an error retrieving the volume using path $($Path): $($_.Exception.Message)", $_)
+ }
+ }
+ elseif ($null -ne $Label) {
+ try {
+ $volume = Get-Volume -FileSystemLabel $Label
+ } catch {
+ $module.FailJson("There was an error retrieving the volume using label $($Label): $($_.Exception.Message)", $_)
+ }
+ }
+ else {
+ $module.FailJson("Unable to locate volume: drive_letter, path and label were not specified")
+ }
+
+ return $volume
+}
+
+function Format-AnsibleVolume {
+ param(
+ $Path,
+ $Label,
+ $FileSystem,
+ $Full,
+ $UseLargeFRS,
+ $Compress,
+ $SetIntegrityStreams,
+ $AllocationUnitSize
+ )
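+    # Build the Format-Volume argument set incrementally and splat it below so
+    # only the options the user actually supplied are passed through.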
+ $parameters = @{
+ Path = $Path
+ Full = $Full
+ }
+ if ($null -ne $UseLargeFRS) {
+ $parameters.Add("UseLargeFRS", $UseLargeFRS)
+ }
+ if ($null -ne $SetIntegrityStreams) {
+ $parameters.Add("SetIntegrityStreams", $SetIntegrityStreams)
+ }
+ if ($null -ne $Compress){
+ $parameters.Add("Compress", $Compress)
+ }
+ if ($null -ne $Label) {
+ $parameters.Add("NewFileSystemLabel", $Label)
+ }
+ if ($null -ne $FileSystem) {
+ $parameters.Add("FileSystem", $FileSystem)
+ }
+ if ($null -ne $AllocationUnitSize) {
+ $parameters.Add("AllocationUnitSize", $AllocationUnitSize)
+ }
+
+ Format-Volume @parameters -Confirm:$false | Out-Null
+
+}
+
+$ansible_volume = Get-AnsibleVolume -DriveLetter $drive_letter -Path $path -Label $label
+$ansible_file_system = $ansible_volume.FileSystem
+$ansible_volume_size = $ansible_volume.Size
+$ansible_volume_alu = (Get-CimInstance -ClassName Win32_Volume -Filter "DeviceId = '$($ansible_volume.path.replace('\','\\'))'" -Property BlockSize).BlockSize
+
+$ansible_partition = Get-Partition -Volume $ansible_volume
+
+if (-not $force_format -and $null -ne $allocation_unit_size -and $ansible_volume_alu -ne 0 -and $null -ne $ansible_volume_alu -and $allocation_unit_size -ne $ansible_volume_alu) {
+ $module.FailJson("Force format must be specified since target allocation unit size: $($allocation_unit_size) is different from the current allocation unit size of the volume: $($ansible_volume_alu)")
+}
+
+# Assume the volume is not pristine unless the loop below proves otherwise; this
+# also avoids referencing an undefined $pristine under Set-StrictMode when no
+# access path qualifies.
+$pristine = $false
+
+foreach ($access_path in $ansible_partition.AccessPaths) {
+ if ($access_path -ne $Path) {
+ if ($null -ne $file_system -and
+ -not [string]::IsNullOrEmpty($ansible_file_system) -and
+ $file_system -ne $ansible_file_system)
+ {
+ if (-not $force_format)
+ {
+ $no_files_in_volume = (Get-ChildItem -LiteralPath $access_path -ErrorAction SilentlyContinue | Measure-Object).Count -eq 0
+ if($no_files_in_volume)
+ {
+ $module.FailJson("Force format must be specified since target file system: $($file_system) is different from the current file system of the volume: $($ansible_file_system.ToLower())")
+ }
+ else
+ {
+ $module.FailJson("Force format must be specified to format non-pristine volumes")
+ }
+ }
+ }
+ else
+ {
+ $pristine = -not $force_format
+ }
+ }
+}
+
+if ($force_format) {
+ if (-not $module.CheckMode) {
+ Format-AnsibleVolume -Path $ansible_volume.Path -Full $full_format -Label $new_label -FileSystem $file_system -SetIntegrityStreams $integrity_streams -UseLargeFRS $large_frs -Compress $compress_volume -AllocationUnitSize $allocation_unit_size
+ }
+ $module.Result.changed = $true
+}
+else {
+ if ($pristine) {
+ if ($null -eq $new_label) {
+ $new_label = $ansible_volume.FileSystemLabel
+ }
+ # Conditions for formatting
+ if ($ansible_volume_size -eq 0 -or
+ $ansible_volume.FileSystemLabel -ne $new_label) {
+ if (-not $module.CheckMode) {
+ Format-AnsibleVolume -Path $ansible_volume.Path -Full $full_format -Label $new_label -FileSystem $file_system -SetIntegrityStreams $integrity_streams -UseLargeFRS $large_frs -Compress $compress_volume -AllocationUnitSize $allocation_unit_size
+ }
+ $module.Result.changed = $true
+ }
+ }
+}
+
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_format.py b/test/support/windows-integration/plugins/modules/win_format.py
new file mode 100644
index 00000000..f8f18ed7
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_format.py
@@ -0,0 +1,103 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2019, Varun Chopra (@chopraaa) <v@chopraaa.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {
+ 'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'
+}
+
+DOCUMENTATION = r'''
+module: win_format
+version_added: '2.8'
+short_description: Formats an existing volume or a new volume on an existing partition on Windows
+description:
+    - The M(win_format) module formats an existing volume or a new volume on an existing partition on Windows.
+options:
+ drive_letter:
+ description:
+ - Used to specify the drive letter of the volume to be formatted.
+ type: str
+ path:
+ description:
+ - Used to specify the path to the volume to be formatted.
+ type: str
+ label:
+ description:
+ - Used to specify the label of the volume to be formatted.
+ type: str
+ new_label:
+ description:
+ - Used to specify the new file system label of the formatted volume.
+ type: str
+ file_system:
+ description:
+ - Used to specify the file system to be used when formatting the target volume.
+ type: str
+ choices: [ ntfs, refs, exfat, fat32, fat ]
+ allocation_unit_size:
+ description:
+ - Specifies the cluster size to use when formatting the volume.
+ - If no cluster size is specified when you format a partition, defaults are selected based on
+ the size of the partition.
+ - This value must be a multiple of the physical sector size of the disk.
+ type: int
+ large_frs:
+ description:
+ - Specifies that large File Record System (FRS) should be used.
+ type: bool
+ compress:
+ description:
+ - Enable compression on the resulting NTFS volume.
+ - NTFS compression is not supported where I(allocation_unit_size) is more than 4096.
+ type: bool
+ integrity_streams:
+ description:
+ - Enable integrity streams on the resulting ReFS volume.
+ type: bool
+ full:
+ description:
+ - A full format writes to every sector of the disk, takes much longer to perform than the
+ default (quick) format, and is not recommended on storage that is thinly provisioned.
+ - Specify C(true) for full format.
+ type: bool
+ force:
+ description:
+ - Specify if formatting should be forced for volumes that are not created from new partitions
+ or if the source and target file system are different.
+ type: bool
+notes:
+ - Microsoft Windows Server 2012 or Microsoft Windows 8 or newer is required to use this module. To check if your system is compatible, see
+ U(https://docs.microsoft.com/en-us/windows/desktop/sysinfo/operating-system-version).
+  - One of the three parameters (I(drive_letter), I(path) and I(label)) is mandatory to identify the target
+    volume, but more than one cannot be specified at the same time.
+ - This module is idempotent if I(force) is not specified and file system labels remain preserved.
+ - For more information, see U(https://docs.microsoft.com/en-us/previous-versions/windows/desktop/stormgmt/format-msft-volume)
+seealso:
+ - module: win_disk_facts
+ - module: win_partition
+author:
+ - Varun Chopra (@chopraaa) <v@chopraaa.com>
+'''
+
+EXAMPLES = r'''
+- name: Create a partition with drive letter D and size 5 GiB
+ win_partition:
+ drive_letter: D
+ partition_size: 5 GiB
+ disk_number: 1
+
+- name: Full format the newly created partition as NTFS and label it
+ win_format:
+ drive_letter: D
+ file_system: NTFS
+ new_label: Formatted
+ full: True
+'''
+
+RETURN = r'''
+#
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_get_url.ps1 b/test/support/windows-integration/plugins/modules/win_get_url.ps1
new file mode 100644
index 00000000..1d8dd5a3
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_get_url.ps1
@@ -0,0 +1,274 @@
+#!powershell
+
+# Copyright: (c) 2015, Paul Durivage <paul.durivage@rackspace.com>
+# Copyright: (c) 2015, Tal Auslander <tal@cloudshare.com>
+# Copyright: (c) 2017, Dag Wieers <dag@wieers.com>
+# Copyright: (c) 2019, Viktor Utkin <viktor_utkin@epam.com>
+# Copyright: (c) 2019, Uladzimir Klybik <uladzimir_klybik@epam.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+#Requires -Module Ansible.ModuleUtils.FileUtil
+#Requires -Module Ansible.ModuleUtils.WebRequest
+
+$spec = @{
+ options = @{
+ url = @{ type="str"; required=$true }
+ dest = @{ type='path'; required=$true }
+ force = @{ type='bool'; default=$true }
+ checksum = @{ type='str' }
+ checksum_algorithm = @{ type='str'; default='sha1'; choices = @("md5", "sha1", "sha256", "sha384", "sha512") }
+ checksum_url = @{ type='str' }
+
+ # Defined for the alias backwards compatibility, remove once aliases are removed
+ url_username = @{
+ aliases = @("user", "username")
+ deprecated_aliases = @(
+ @{ name = "user"; version = "2.14" },
+ @{ name = "username"; version = "2.14" }
+ )
+ }
+ url_password = @{
+ aliases = @("password")
+ deprecated_aliases = @(
+ @{ name = "password"; version = "2.14" }
+ )
+ }
+ }
+ mutually_exclusive = @(
+ ,@('checksum', 'checksum_url')
+ )
+ supports_check_mode = $true
+}
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWebRequestSpec))
+
+$url = $module.Params.url
+$dest = $module.Params.dest
+$force = $module.Params.force
+$checksum = $module.Params.checksum
+$checksum_algorithm = $module.Params.checksum_algorithm
+$checksum_url = $module.Params.checksum_url
+
+$module.Result.elapsed = 0
+$module.Result.url = $url
+
+Function Get-ChecksumFromUri {
+ param(
+ [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module,
+ [Parameter(Mandatory=$true)][Uri]$Uri,
+ [Uri]$SourceUri
+ )
+
+ $script = {
+ param($Response, $Stream)
+
+ $read_stream = New-Object -TypeName System.IO.StreamReader -ArgumentList $Stream
+ $web_checksum = $read_stream.ReadToEnd()
+ $basename = (Split-Path -Path $SourceUri.LocalPath -Leaf)
+ $basename = [regex]::Escape($basename)
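+        # The checksum file is expected to hold sha*sum-style records, one per
+        # line, e.g. "<hex-digest>  ./earthrise.jpg"; select the record whose
+        # trailing file name matches the basename of the download URL.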
+ $web_checksum_str = $web_checksum -split '\r?\n' | Select-String -Pattern $("\s+\.?\/?\\?" + $basename + "\s*$")
+ if (-not $web_checksum_str) {
+ $Module.FailJson("Checksum record not found for file name '$basename' in file from url: '$Uri'")
+ }
+
+ $web_checksum_str_splitted = $web_checksum_str[0].ToString().split(" ", 2)
+ $hash_from_file = $web_checksum_str_splitted[0].Trim()
+ # Remove any non-alphanumeric characters
+ $hash_from_file = $hash_from_file -replace '\W+', ''
+
+ Write-Output -InputObject $hash_from_file
+ }
+ $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module
+
+ try {
+ Invoke-WithWebRequest -Module $Module -Request $web_request -Script $script
+ } catch {
+ $Module.FailJson("Error when getting the remote checksum from '$Uri'. $($_.Exception.Message)", $_)
+ }
+}
+
+Function Compare-ModifiedFile {
+ <#
+ .SYNOPSIS
+ Compares the remote URI resource against the local Dest resource. Will
+ return true if the LastWriteTime/LastModificationDate of the remote is
+ newer than the local resource date.
+ #>
+ param(
+ [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module,
+ [Parameter(Mandatory=$true)][Uri]$Uri,
+ [Parameter(Mandatory=$true)][String]$Dest
+ )
+
+ $dest_last_mod = (Get-AnsibleItem -Path $Dest).LastWriteTimeUtc
+
+ # If the URI is a file we don't need to go through the whole WebRequest
+ if ($Uri.IsFile) {
+ $src_last_mod = (Get-AnsibleItem -Path $Uri.AbsolutePath).LastWriteTimeUtc
+ } else {
+ $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module
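+        # Only the timestamp is needed: FTP exposes a dedicated date-stamp
+        # command, while HTTP uses a HEAD request and its Last-Modified header.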
+ $web_request.Method = switch ($web_request.GetType().Name) {
+ FtpWebRequest { [System.Net.WebRequestMethods+Ftp]::GetDateTimestamp }
+ HttpWebRequest { [System.Net.WebRequestMethods+Http]::Head }
+ }
+ $script = { param($Response, $Stream); $Response.LastModified }
+
+ try {
+ $src_last_mod = Invoke-WithWebRequest -Module $Module -Request $web_request -Script $script
+ } catch {
+ $Module.FailJson("Error when requesting 'Last-Modified' date from '$Uri'. $($_.Exception.Message)", $_)
+ }
+ }
+
+ # Return $true if the Uri LastModification date is newer than the Dest LastModification date
+ ((Get-Date -Date $src_last_mod).ToUniversalTime() -gt $dest_last_mod)
+}
+
+Function Get-Checksum {
+ param(
+ [Parameter(Mandatory=$true)][String]$Path,
+ [String]$Algorithm = "sha1"
+ )
+
+ switch ($Algorithm) {
+ 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider }
+ 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider }
+ 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider }
+ 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider }
+ 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider }
+ }
+
+ $fs = [System.IO.File]::Open($Path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read,
+ [System.IO.FileShare]::ReadWrite)
+ try {
+ $hash = [System.BitConverter]::ToString($sp.ComputeHash($fs)).Replace("-", "").ToLower()
+ } finally {
+ $fs.Dispose()
+ }
+ return $hash
+}
+
+Function Invoke-DownloadFile {
+ param(
+ [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module,
+ [Parameter(Mandatory=$true)][Uri]$Uri,
+ [Parameter(Mandatory=$true)][String]$Dest,
+ [String]$Checksum,
+ [String]$ChecksumAlgorithm
+ )
+
+ # Check $dest parent folder exists before attempting download, which avoids unhelpful generic error message.
+ $dest_parent = Split-Path -LiteralPath $Dest
+ if (-not (Test-Path -LiteralPath $dest_parent -PathType Container)) {
+ $module.FailJson("The path '$dest_parent' does not exist for destination '$Dest', or is not visible to the current user. Ensure download destination folder exists (perhaps using win_file state=directory) before win_get_url runs.")
+ }
+
+ $download_script = {
+ param($Response, $Stream)
+
+ # Download the file to a temporary directory so we can compare it
+ $tmp_dest = Join-Path -Path $Module.Tmpdir -ChildPath ([System.IO.Path]::GetRandomFileName())
+ $fs = [System.IO.File]::Create($tmp_dest)
+ try {
+ $Stream.CopyTo($fs)
+ $fs.Flush()
+ } finally {
+ $fs.Dispose()
+ }
+ $tmp_checksum = Get-Checksum -Path $tmp_dest -Algorithm $ChecksumAlgorithm
+ $Module.Result.checksum_src = $tmp_checksum
+
+ # If the checksum has been set, verify the checksum of the remote against the input checksum.
+ if ($Checksum -and $Checksum -ne $tmp_checksum) {
+ $Module.FailJson(("The checksum for {0} did not match '{1}', it was '{2}'" -f $Uri, $Checksum, $tmp_checksum))
+ }
+
+ $download = $true
+ if (Test-Path -LiteralPath $Dest) {
+ # Validate the remote checksum against the existing downloaded file
+ $dest_checksum = Get-Checksum -Path $Dest -Algorithm $ChecksumAlgorithm
+
+ # If we don't need to download anything, save the dest checksum so we don't waste time calculating it
+ # again at the end of the script
+ if ($dest_checksum -eq $tmp_checksum) {
+ $download = $false
+ $Module.Result.checksum_dest = $dest_checksum
+ $Module.Result.size = (Get-AnsibleItem -Path $Dest).Length
+ }
+ }
+
+ if ($download) {
+ Copy-Item -LiteralPath $tmp_dest -Destination $Dest -Force -WhatIf:$Module.CheckMode > $null
+ $Module.Result.changed = $true
+ }
+ }
+ $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module
+
+ try {
+ Invoke-WithWebRequest -Module $Module -Request $web_request -Script $download_script
+ } catch {
+ $Module.FailJson("Error downloading '$Uri' to '$Dest': $($_.Exception.Message)", $_)
+ }
+}
+
+# Use last part of url for dest file name if a directory is supplied for $dest
+if (Test-Path -LiteralPath $dest -PathType Container) {
+ $uri = [System.Uri]$url
+ $basename = Split-Path -Path $uri.LocalPath -Leaf
+ if ($uri.LocalPath -and $uri.LocalPath -ne '/' -and $basename) {
+ $url_basename = Split-Path -Path $uri.LocalPath -Leaf
+ $dest = Join-Path -Path $dest -ChildPath $url_basename
+ } else {
+ $dest = Join-Path -Path $dest -ChildPath $uri.Host
+ }
+
+ # Ensure we have a string instead of a PS object to avoid serialization issues
+ $dest = $dest.ToString()
+} elseif (([System.IO.Path]::GetFileName($dest)) -eq '') {
+ # We have a trailing path separator
+ $module.FailJson("The destination path '$dest' does not exist, or is not visible to the current user. Ensure download destination folder exists (perhaps using win_file state=directory) before win_get_url runs.")
+}
+
+$module.Result.dest = $dest
+
+if ($checksum) {
+ $checksum = $checksum.Trim().ToLower()
+}
+if ($checksum_algorithm) {
+ $checksum_algorithm = $checksum_algorithm.Trim().ToLower()
+}
+if ($checksum_url) {
+ $checksum_url = $checksum_url.Trim()
+}
+
+# Check for case $checksum variable contain url. If yes, get file data from url and replace original value in $checksum
+if ($checksum_url) {
+ $checksum_uri = [System.Uri]$checksum_url
+ if ($checksum_uri.Scheme -notin @("file", "ftp", "http", "https")) {
+ $module.FailJson("Unsupported 'checksum_url' value for '$dest': '$checksum_url'")
+ }
+
+ $checksum = Get-ChecksumFromUri -Module $Module -Uri $checksum_uri -SourceUri $url
+}
+
+if ($force -or -not (Test-Path -LiteralPath $dest)) {
+ # force=yes or dest does not exist, download the file
+ # Note: Invoke-DownloadFile will compare the checksums internally if dest exists
+ Invoke-DownloadFile -Module $module -Uri $url -Dest $dest -Checksum $checksum `
+ -ChecksumAlgorithm $checksum_algorithm
+} else {
+ # force=no, we want to check the last modified dates and only download if they don't match
+ $is_modified = Compare-ModifiedFile -Module $module -Uri $url -Dest $dest
+ if ($is_modified) {
+ Invoke-DownloadFile -Module $module -Uri $url -Dest $dest -Checksum $checksum `
+ -ChecksumAlgorithm $checksum_algorithm
+ }
+}
+
+if ((-not $module.Result.ContainsKey("checksum_dest")) -and (Test-Path -LiteralPath $dest)) {
+ # Calculate the dest file checksum if it hasn't already been done
+ $module.Result.checksum_dest = Get-Checksum -Path $dest -Algorithm $checksum_algorithm
+ $module.Result.size = (Get-AnsibleItem -Path $dest).Length
+}
+
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_get_url.py b/test/support/windows-integration/plugins/modules/win_get_url.py
new file mode 100644
index 00000000..ef5b5f97
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_get_url.py
@@ -0,0 +1,215 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com>, and others
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# This is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_get_url
+version_added: "1.7"
+short_description: Downloads file from HTTP, HTTPS, or FTP to node
+description:
+- Downloads files from HTTP, HTTPS, or FTP to the remote server.
+- The remote server I(must) have direct access to the remote resource.
+- For non-Windows targets, use the M(get_url) module instead.
+options:
+ url:
+ description:
+ - The full URL of a file to download.
+ type: str
+ required: yes
+ dest:
+ description:
+    - The location to save the downloaded file.
+ - Be sure to include a filename and extension as appropriate.
+ type: path
+ required: yes
+ force:
+ description:
+ - If C(yes), will download the file every time and replace the file if the contents change. If C(no), will only
+ download the file if it does not exist or the remote file has been
+ modified more recently than the local file.
+    - This works by sending an HTTP HEAD request to retrieve the last modified
+      time of the requested resource, so for this to work, the remote web
+      server must support HEAD requests.
+ type: bool
+ default: yes
+ version_added: "2.0"
+ checksum:
+ description:
+ - If a I(checksum) is passed to this parameter, the digest of the
+ destination file will be calculated after it is downloaded to ensure
+ its integrity and verify that the transfer completed successfully.
+ - This option cannot be set with I(checksum_url).
+ type: str
+ version_added: "2.8"
+ checksum_algorithm:
+ description:
+ - Specifies the hashing algorithm used when calculating the checksum of
+ the remote and destination file.
+ type: str
+ choices:
+ - md5
+ - sha1
+ - sha256
+ - sha384
+ - sha512
+ default: sha1
+ version_added: "2.8"
+ checksum_url:
+ description:
+ - Specifies a URL that contains the checksum values for the resource at
+ I(url).
+ - Like C(checksum), this is used to verify the integrity of the remote
+ transfer.
+ - This option cannot be set with I(checksum).
+ type: str
+ version_added: "2.8"
+ url_username:
+ description:
+ - The username to use for authentication.
+ - The aliases I(user) and I(username) are deprecated and will be removed in
+ Ansible 2.14.
+ aliases:
+ - user
+ - username
+ url_password:
+ description:
+ - The password for I(url_username).
+ - The alias I(password) is deprecated and will be removed in Ansible 2.14.
+ aliases:
+ - password
+ proxy_url:
+ version_added: "2.0"
+ proxy_username:
+ version_added: "2.0"
+ proxy_password:
+ version_added: "2.0"
+ headers:
+ version_added: "2.4"
+ use_proxy:
+ version_added: "2.4"
+ follow_redirects:
+ version_added: "2.9"
+ maximum_redirection:
+ version_added: "2.9"
+ client_cert:
+ version_added: "2.9"
+ client_cert_password:
+ version_added: "2.9"
+ method:
+ description:
+ - This option is not for use with C(win_get_url) and should be ignored.
+ version_added: "2.9"
+notes:
+- If your URL includes an escaped slash character (%2F) this module will convert it to a real slash.
+ This is a result of the behaviour of the System.Uri class as described in
+ L(the documentation,https://docs.microsoft.com/en-us/dotnet/framework/configure-apps/file-schema/network/schemesettings-element-uri-settings#remarks).
+- Since Ansible 2.8, the module will skip reporting a change if the remote
+  checksum is the same as the local one even when C(force=yes). This is to
+ better align with M(get_url).
+extends_documentation_fragment:
+- url_windows
+seealso:
+- module: get_url
+- module: uri
+- module: win_uri
+author:
+- Paul Durivage (@angstwad)
+- Takeshi Kuramochi (@tksarah)
+'''
+
+EXAMPLES = r'''
+- name: Download earthrise.jpg to specified path
+ win_get_url:
+ url: http://www.example.com/earthrise.jpg
+ dest: C:\Users\RandomUser\earthrise.jpg
+
+- name: Download earthrise.jpg to specified path only if modified
+ win_get_url:
+ url: http://www.example.com/earthrise.jpg
+ dest: C:\Users\RandomUser\earthrise.jpg
+ force: no
+
+- name: Download earthrise.jpg to specified path through a proxy server.
+ win_get_url:
+ url: http://www.example.com/earthrise.jpg
+ dest: C:\Users\RandomUser\earthrise.jpg
+ proxy_url: http://10.0.0.1:8080
+ proxy_username: username
+ proxy_password: password
+
+- name: Download file from FTP with authentication
+ win_get_url:
+ url: ftp://server/file.txt
+ dest: '%TEMP%\ftp-file.txt'
+ url_username: ftp-user
+ url_password: ftp-password
+
+- name: Download src with sha256 checksum url
+ win_get_url:
+ url: http://www.example.com/earthrise.jpg
+ dest: C:\temp\earthrise.jpg
+ checksum_url: http://www.example.com/sha256sum.txt
+ checksum_algorithm: sha256
+ force: True
+
+- name: Download src and verify sha1 checksum
+ win_get_url:
+ url: http://www.example.com/earthrise.jpg
+ dest: C:\temp\earthrise.jpg
+ checksum: a97e6837f60cec6da4491bab387296bbcd72bdba
+ checksum_algorithm: sha1
+ force: True
+'''
+
+RETURN = r'''
+dest:
+ description: destination file/path
+ returned: always
+ type: str
+ sample: C:\Users\RandomUser\earthrise.jpg
+checksum_dest:
+ description: <algorithm> checksum of the file after the download
+ returned: success and dest has been downloaded
+ type: str
+ sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827
+checksum_src:
+ description: <algorithm> checksum of the remote resource
+ returned: force=yes or dest did not exist
+ type: str
+ sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827
+elapsed:
+ description: The elapsed seconds between the start of poll and the end of the module.
+ returned: always
+ type: float
+ sample: 2.1406487
+size:
+ description: size of the dest file
+ returned: success
+ type: int
+ sample: 1220
+url:
+ description: requested url
+ returned: always
+ type: str
+ sample: http://www.example.com/earthrise.jpg
+msg:
+ description: Error message, or HTTP status message from web-server
+ returned: always
+ type: str
+ sample: OK
+status_code:
+ description: HTTP status code
+ returned: always
+ type: int
+ sample: 200
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_lineinfile.ps1 b/test/support/windows-integration/plugins/modules/win_lineinfile.ps1
new file mode 100644
index 00000000..38dd8b8b
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_lineinfile.ps1
@@ -0,0 +1,450 @@
+#!powershell
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.Backup
+
+function WriteLines($outlines, $path, $linesep, $encodingobj, $validate, $check_mode) {
+ Try {
+ $temppath = [System.IO.Path]::GetTempFileName();
+ }
+ Catch {
+ Fail-Json @{} "Cannot create temporary file! ($($_.Exception.Message))";
+ }
+ $joined = $outlines -join $linesep;
+ [System.IO.File]::WriteAllText($temppath, $joined, $encodingobj);
+
+ If ($validate) {
+
+ If (-not ($validate -like "*%s*")) {
+ Fail-Json @{} "validate must contain %s: $validate";
+ }
+
+ $validate = $validate.Replace("%s", $temppath);
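+        # Hypothetical example: validate 'C:\tools\check.exe %s' becomes
+        # 'C:\tools\check.exe <temp file>', so the staged content is vetted
+        # before it is copied over the real target below.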
+
+ $parts = [System.Collections.ArrayList] $validate.Split(" ");
+ $cmdname = $parts[0];
+
+ $cmdargs = $validate.Substring($cmdname.Length + 1);
+
+        # The output streams must be redirected (with UseShellExecute disabled)
+        # before the process starts, otherwise reading them below throws.
+        $psi = New-Object -TypeName System.Diagnostics.ProcessStartInfo;
+        $psi.FileName = $cmdname;
+        $psi.Arguments = $cmdargs;
+        $psi.RedirectStandardOutput = $true;
+        $psi.RedirectStandardError = $true;
+        $psi.UseShellExecute = $false;
+        $process = [System.Diagnostics.Process]::Start($psi);
+        [string] $output = $process.StandardOutput.ReadToEnd();
+        [string] $error = $process.StandardError.ReadToEnd();
+        $process.WaitForExit();
+
+        If ($process.ExitCode -ne 0) {
+ Remove-Item $temppath -force;
+ Fail-Json @{} "failed to validate $cmdname $cmdargs with error: $output $error";
+ }
+
+ }
+
+ # Commit changes to the path
+ $cleanpath = $path.Replace("/", "\");
+ Try {
+ Copy-Item -Path $temppath -Destination $cleanpath -Force -WhatIf:$check_mode;
+ }
+ Catch {
+ Fail-Json @{} "Cannot write to: $cleanpath ($($_.Exception.Message))";
+ }
+
+ Try {
+ Remove-Item -Path $temppath -Force -WhatIf:$check_mode;
+ }
+ Catch {
+ Fail-Json @{} "Cannot remove temporary file: $temppath ($($_.Exception.Message))";
+ }
+
+ return $joined;
+
+}
+
+
+# Implement the functionality for state == 'present'
+function Present($path, $regex, $line, $insertafter, $insertbefore, $create, $backup, $backrefs, $validate, $encodingobj, $linesep, $check_mode, $diff_support) {
+
+ # Note that we have to clean up the path because ansible wants to treat / and \ as
+ # interchangeable in windows pathnames, but .NET framework internals do not support that.
+ $cleanpath = $path.Replace("/", "\");
+
+ # Check if path exists. If it does not exist, either create it if create == "yes"
+ # was specified or fail with a reasonable error message.
+ If (-not (Test-Path -LiteralPath $path)) {
+ If (-not $create) {
+ Fail-Json @{} "Path $path does not exist !";
+ }
+ # Create new empty file, using the specified encoding to write correct BOM
+ [System.IO.File]::WriteAllLines($cleanpath, "", $encodingobj);
+ }
+
+ # Initialize result information
+ $result = @{
+ backup = "";
+ changed = $false;
+ msg = "";
+ }
+
+ # Read the dest file lines using the indicated encoding into a mutable ArrayList.
+ $before = [System.IO.File]::ReadAllLines($cleanpath, $encodingobj)
+ If ($null -eq $before) {
+ $lines = New-Object System.Collections.ArrayList;
+ }
+ Else {
+ $lines = [System.Collections.ArrayList] $before;
+ }
+
+ if ($diff_support) {
+ $result.diff = @{
+ before = $before -join $linesep;
+ }
+ }
+
+ # Compile the regex specified, if provided
+ $mre = $null;
+ If ($regex) {
+ $mre = New-Object Regex $regex, 'Compiled';
+ }
+
+ # Compile the regex for insertafter or insertbefore, if provided
+ $insre = $null;
+ If ($insertafter -and $insertafter -ne "BOF" -and $insertafter -ne "EOF") {
+ $insre = New-Object Regex $insertafter, 'Compiled';
+ }
+ ElseIf ($insertbefore -and $insertbefore -ne "BOF") {
+ $insre = New-Object Regex $insertbefore, 'Compiled';
+ }
+
+ # index[0] is the line num where regex has been found
+ # index[1] is the line num where insertafter/insertbefore has been found
+ $index = -1, -1;
+ $lineno = 0;
+
+ # The latest match object and matched line
+ $matched_line = "";
+
+ # Iterate through the lines in the file looking for matches
+ Foreach ($cur_line in $lines) {
+ If ($regex) {
+ $m = $mre.Match($cur_line);
+ $match_found = $m.Success;
+ If ($match_found) {
+ $matched_line = $cur_line;
+ }
+ }
+ Else {
+ $match_found = $line -ceq $cur_line;
+ }
+ If ($match_found) {
+ $index[0] = $lineno;
+ }
+ ElseIf ($insre -and $insre.Match($cur_line).Success) {
+ If ($insertafter) {
+ $index[1] = $lineno + 1;
+ }
+ If ($insertbefore) {
+ $index[1] = $lineno;
+ }
+ }
+ $lineno = $lineno + 1;
+ }
+
+ If ($index[0] -ne -1) {
+ If ($backrefs) {
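+            # With backrefs, $line may use .NET regex substitutions, e.g. regex
+            # '^(name=).*$' with line '$1new_value' rewrites only the value part.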
+ $new_line = [regex]::Replace($matched_line, $regex, $line);
+ }
+ Else {
+ $new_line = $line;
+ }
+ If ($lines[$index[0]] -cne $new_line) {
+ $lines[$index[0]] = $new_line;
+ $result.changed = $true;
+ $result.msg = "line replaced";
+ }
+ }
+ ElseIf ($backrefs) {
+ # No matches - no-op
+ }
+ ElseIf ($insertbefore -eq "BOF" -or $insertafter -eq "BOF") {
+ $lines.Insert(0, $line);
+ $result.changed = $true;
+ $result.msg = "line added";
+ }
+ ElseIf ($insertafter -eq "EOF" -or $index[1] -eq -1) {
+ $lines.Add($line) > $null;
+ $result.changed = $true;
+ $result.msg = "line added";
+ }
+ Else {
+ $lines.Insert($index[1], $line);
+ $result.changed = $true;
+ $result.msg = "line added";
+ }
+
+ # Write changes to the path if changes were made
+ If ($result.changed) {
+
+ # Write backup file if backup == "yes"
+ If ($backup) {
+ $result.backup_file = Backup-File -path $path -WhatIf:$check_mode
+ # Ensure backward compatibility (deprecate in future)
+ $result.backup = $result.backup_file
+ }
+
+ $writelines_params = @{
+ outlines = $lines
+ path = $path
+ linesep = $linesep
+ encodingobj = $encodingobj
+ validate = $validate
+ check_mode = $check_mode
+ }
+ $after = WriteLines @writelines_params;
+
+ if ($diff_support) {
+ $result.diff.after = $after;
+ }
+ }
+
+ $result.encoding = $encodingobj.WebName;
+
+ Exit-Json $result;
+}
+
+
+# Implement the functionality for state == 'absent'
+function Absent($path, $regex, $line, $backup, $validate, $encodingobj, $linesep, $check_mode, $diff_support) {
+
+ # Check if path exists. If it does not exist, fail with a reasonable error message.
+ If (-not (Test-Path -LiteralPath $path)) {
+ Fail-Json @{} "Path $path does not exist !";
+ }
+
+ # Initialize result information
+ $result = @{
+ backup = "";
+ changed = $false;
+ msg = "";
+ }
+
+ # Read the dest file lines using the indicated encoding into a mutable ArrayList. Note
+ # that we have to clean up the path because ansible wants to treat / and \ as
+ # interchangeable in windows pathnames, but .NET framework internals do not support that.
+ $cleanpath = $path.Replace("/", "\");
+ $before = [System.IO.File]::ReadAllLines($cleanpath, $encodingobj);
+ If ($null -eq $before) {
+ $lines = New-Object System.Collections.ArrayList;
+ }
+ Else {
+ $lines = [System.Collections.ArrayList] $before;
+ }
+
+ if ($diff_support) {
+ $result.diff = @{
+ before = $before -join $linesep;
+ }
+ }
+
+ # Compile the regex specified, if provided
+ $cre = $null;
+ If ($regex) {
+ $cre = New-Object Regex $regex, 'Compiled';
+ }
+
+ $found = New-Object System.Collections.ArrayList;
+ $left = New-Object System.Collections.ArrayList;
+
+ Foreach ($cur_line in $lines) {
+ If ($regex) {
+ $m = $cre.Match($cur_line);
+ $match_found = $m.Success;
+ }
+ Else {
+ $match_found = $line -ceq $cur_line;
+ }
+ If ($match_found) {
+ $found.Add($cur_line) > $null;
+ $result.changed = $true;
+ }
+ Else {
+ $left.Add($cur_line) > $null;
+ }
+ }
+
+ # Write changes to the path if changes were made
+ If ($result.changed) {
+
+ # Write backup file if backup == "yes"
+ If ($backup) {
+ $result.backup_file = Backup-File -path $path -WhatIf:$check_mode
+ # Ensure backward compatibility (deprecate in future)
+ $result.backup = $result.backup_file
+ }
+
+ $writelines_params = @{
+ outlines = $left
+ path = $path
+ linesep = $linesep
+ encodingobj = $encodingobj
+ validate = $validate
+ check_mode = $check_mode
+ }
+ $after = WriteLines @writelines_params;
+
+ if ($diff_support) {
+ $result.diff.after = $after;
+ }
+ }
+
+ $result.encoding = $encodingobj.WebName;
+ $result.found = $found.Count;
+ $result.msg = "$($found.Count) line(s) removed";
+
+ Exit-Json $result;
+}
+
+
+# Parse the parameters file dropped by the Ansible machinery
+$params = Parse-Args $args -supports_check_mode $true;
+$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false;
+$diff_support = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false;
+
+# Initialize defaults for input parameters.
+$path = Get-AnsibleParam -obj $params -name "path" -type "path" -failifempty $true -aliases "dest","destfile","name";
+$regex = Get-AnsibleParam -obj $params -name "regex" -type "str" -aliases "regexp";
+$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent";
+$line = Get-AnsibleParam -obj $params -name "line" -type "str";
+$backrefs = Get-AnsibleParam -obj $params -name "backrefs" -type "bool" -default $false;
+$insertafter = Get-AnsibleParam -obj $params -name "insertafter" -type "str";
+$insertbefore = Get-AnsibleParam -obj $params -name "insertbefore" -type "str";
+$create = Get-AnsibleParam -obj $params -name "create" -type "bool" -default $false;
+$backup = Get-AnsibleParam -obj $params -name "backup" -type "bool" -default $false;
+$validate = Get-AnsibleParam -obj $params -name "validate" -type "str";
+$encoding = Get-AnsibleParam -obj $params -name "encoding" -type "str" -default "auto";
+$newline = Get-AnsibleParam -obj $params -name "newline" -type "str" -default "windows" -validateset "unix","windows";
+
+# Fail if the path is not a file
+If (Test-Path -LiteralPath $path -PathType "container") {
+ Fail-Json @{} "Path $path is a directory";
+}
+
+# Default to windows line separator - probably most common
+$linesep = "`r`n"
+If ($newline -eq "unix") {
+ $linesep = "`n";
+}
+
+# Figure out the proper encoding to use for reading / writing the target file.
+
+# The default encoding is UTF-8 without BOM
+$encodingobj = [System.Text.UTF8Encoding] $false;
+
+# If an explicit encoding is specified, use that instead
+If ($encoding -ne "auto") {
+ $encodingobj = [System.Text.Encoding]::GetEncoding($encoding);
+}
+
+# Otherwise see if we can determine the current encoding of the target file.
+# If the file doesn't exist yet (create == 'yes') we use the default or
+# explicitly specified encoding set above.
+ElseIf (Test-Path -LiteralPath $path) {
+
+ # Get a sorted list of encodings with preambles, longest first
+ $max_preamble_len = 0;
+ $sortedlist = New-Object System.Collections.SortedList;
+ Foreach ($encodinginfo in [System.Text.Encoding]::GetEncodings()) {
+ $encoding = $encodinginfo.GetEncoding();
+ $plen = $encoding.GetPreamble().Length;
+ If ($plen -gt $max_preamble_len) {
+ $max_preamble_len = $plen;
+ }
+ If ($plen -gt 0) {
+ $sortedlist.Add(-($plen * 1000000 + $encoding.CodePage), $encoding) > $null;
+ }
+ }
+
+ # Get the first N bytes from the file, where N is the max preamble length we saw
+ [Byte[]]$bom = Get-Content -Encoding Byte -ReadCount $max_preamble_len -TotalCount $max_preamble_len -LiteralPath $path;
+
+ # Iterate through the sorted encodings, looking for a full match.
+ $found = $false;
+ Foreach ($encoding in $sortedlist.GetValueList()) {
+ $preamble = $encoding.GetPreamble();
+ If ($preamble -and $bom) {
+ Foreach ($i in 0..($preamble.Length - 1)) {
+ If ($i -ge $bom.Length) {
+ break;
+ }
+ If ($preamble[$i] -ne $bom[$i]) {
+ break;
+ }
+ ElseIf ($i + 1 -eq $preamble.Length) {
+ $encodingobj = $encoding;
+ $found = $true;
+ }
+ }
+ If ($found) {
+ break;
+ }
+ }
+ }
+}
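+
+# Sketch of the detection above, assuming a UTF-8 file with a BOM: the first
+# bytes read are 0xEF,0xBB,0xBF, which fully match the UTF-8 preamble, so
+# $encodingobj is switched to that encoding. The negative sort key
+# -($plen * 1000000 + CodePage) makes the ascending SortedList yield
+# longest-preamble candidates first, so UTF-32 (4-byte BOM) is tested before
+# UTF-16 (2-byte BOM).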
+
+
+# Main dispatch - based on the value of 'state', perform argument validation and
+# call the appropriate handler function.
+If ($state -eq "present") {
+
+ If ($backrefs -and -not $regex) {
+ Fail-Json @{} "regexp= is required with backrefs=true";
+ }
+
+ If (-not $line) {
+ Fail-Json @{} "line= is required with state=present";
+ }
+
+    If ($insertbefore -and $insertafter) {
+        # initialize a result container here so Add-Warning has something to attach to
+        $result = @{}
+        Add-Warning $result "Both insertbefore and insertafter parameters found, ignoring `"insertafter=$insertafter`""
+    }
+
+ If (-not $insertbefore -and -not $insertafter) {
+ $insertafter = "EOF";
+ }
+
+ $present_params = @{
+ path = $path
+ regex = $regex
+ line = $line
+ insertafter = $insertafter
+ insertbefore = $insertbefore
+ create = $create
+ backup = $backup
+ backrefs = $backrefs
+ validate = $validate
+ encodingobj = $encodingobj
+ linesep = $linesep
+ check_mode = $check_mode
+ diff_support = $diff_support
+ }
+ Present @present_params;
+
+}
+ElseIf ($state -eq "absent") {
+
+ If (-not $regex -and -not $line) {
+ Fail-Json @{} "one of line= or regexp= is required with state=absent";
+ }
+
+ $absent_params = @{
+ path = $path
+ regex = $regex
+ line = $line
+ backup = $backup
+ validate = $validate
+ encodingobj = $encodingobj
+ linesep = $linesep
+ check_mode = $check_mode
+ diff_support = $diff_support
+ }
+ Absent @absent_params;
+}
diff --git a/test/support/windows-integration/plugins/modules/win_lineinfile.py b/test/support/windows-integration/plugins/modules/win_lineinfile.py
new file mode 100644
index 00000000..f4fb7f5a
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_lineinfile.py
@@ -0,0 +1,180 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_lineinfile
+short_description: Ensure a particular line is in a file, or replace an existing line using a back-referenced regular expression
+description:
+ - This module will search a file for a line, and ensure that it is present or absent.
+ - This is primarily useful when you want to change a single line in a file only.
+version_added: "2.0"
+options:
+ path:
+ description:
+ - The path of the file to modify.
+ - Note that the Windows path delimiter C(\) must be escaped as C(\\) when the line is double quoted.
+ - Before Ansible 2.3 this option was only usable as I(dest), I(destfile) and I(name).
+ type: path
+ required: yes
+ aliases: [ dest, destfile, name ]
+ backup:
+ description:
+ - Determine whether a backup should be created.
+ - When set to C(yes), create a backup file including the timestamp information
+ so you can get the original file back if you somehow clobbered it incorrectly.
+ type: bool
+ default: no
+ regex:
+ description:
+ - The regular expression to look for in every line of the file. For C(state=present), the pattern to replace if found; only the last line found
+ will be replaced. For C(state=absent), the pattern of the line to remove. Uses .NET compatible regular expressions;
+ see U(https://msdn.microsoft.com/en-us/library/hs600312%28v=vs.110%29.aspx).
+ aliases: [ "regexp" ]
+ state:
+ description:
+ - Whether the line should be there or not.
+ type: str
+ choices: [ absent, present ]
+ default: present
+ line:
+ description:
+ - Required for C(state=present). The line to insert/replace into the file. If C(backrefs) is set, may contain backreferences that will get
+ expanded with the C(regexp) capture groups if the regexp matches.
+ - Be aware that the line is processed first on the controller and thus is dependent on yaml quoting rules. Any double quoted line
+ will have control characters, such as '\r\n', expanded. To print such characters literally, use single or no quotes.
+ type: str
+ backrefs:
+ description:
+ - Used with C(state=present). If set, line can contain backreferences (both positional and named) that will get populated if the C(regexp)
+ matches. This flag changes the operation of the module slightly; C(insertbefore) and C(insertafter) will be ignored, and if the C(regexp)
+ doesn't match anywhere in the file, the file will be left unchanged.
+ - If the C(regexp) does match, the last matching line will be replaced by the expanded line parameter.
+ type: bool
+ default: no
+ insertafter:
+ description:
+    - Used with C(state=present). If specified, the line will be inserted after the last match of the specified regular expression. A special value,
+      C(EOF), is available for inserting the line at the end of the file.
+    - If the specified regular expression has no matches, C(EOF) will be used instead. May not be used with C(backrefs).
+ type: str
+ choices: [ EOF, '*regex*' ]
+ default: EOF
+ insertbefore:
+ description:
+    - Used with C(state=present). If specified, the line will be inserted before the last match of the specified regular expression. A special value,
+      C(BOF), is available for inserting the line at the beginning of the file.
+    - If the specified regular expression has no matches, the line will be inserted at the end of the file. May not be used with C(backrefs).
+ type: str
+ choices: [ BOF, '*regex*' ]
+ create:
+ description:
+ - Used with C(state=present). If specified, the file will be created if it does not already exist. By default it will fail if the file is missing.
+ type: bool
+ default: no
+ validate:
+ description:
+ - Validation to run before copying into place. Use %s in the command to indicate the current file to validate.
+ - The command is passed securely so shell features like expansion and pipes won't work.
+ type: str
+ encoding:
+ description:
+ - Specifies the encoding of the source text file to operate on (and thus what the output encoding will be). The default of C(auto) will cause
+ the module to auto-detect the encoding of the source file and ensure that the modified file is written with the same encoding.
+ - An explicit encoding can be passed as a string that is a valid value to pass to the .NET framework System.Text.Encoding.GetEncoding() method -
+ see U(https://msdn.microsoft.com/en-us/library/system.text.encoding%28v=vs.110%29.aspx).
+ - This is mostly useful with C(create=yes) if you want to create a new file with a specific encoding. If C(create=yes) is specified without a
+ specific encoding, the default encoding (UTF-8, no BOM) will be used.
+ type: str
+ default: auto
+ newline:
+ description:
+ - Specifies the line separator style to use for the modified file. This defaults to the windows line separator (C(\r\n)). Note that the indicated
+ line separator will be used for file output regardless of the original line separator that appears in the input file.
+ type: str
+ choices: [ unix, windows ]
+ default: windows
+notes:
+  - As of Ansible 2.3, the I(dest) option has been renamed to I(path), but I(dest) still works as well.
+seealso:
+- module: assemble
+- module: lineinfile
+author:
+- Brian Lloyd (@brianlloyd)
+'''
+
+EXAMPLES = r'''
+# Before Ansible 2.3, option 'dest', 'destfile' or 'name' was used instead of 'path'
+- name: Insert path without converting \r\n
+ win_lineinfile:
+ path: c:\file.txt
+ line: c:\return\new
+
+- win_lineinfile:
+ path: C:\Temp\example.conf
+ regex: '^name='
+ line: 'name=JohnDoe'
+
+- win_lineinfile:
+ path: C:\Temp\example.conf
+ regex: '^name='
+ state: absent
+
+- win_lineinfile:
+ path: C:\Temp\example.conf
+ regex: '^127\.0\.0\.1'
+ line: '127.0.0.1 localhost'
+
+- win_lineinfile:
+ path: C:\Temp\httpd.conf
+ regex: '^Listen '
+ insertafter: '^#Listen '
+ line: Listen 8080
+
+- win_lineinfile:
+ path: C:\Temp\services
+ regex: '^# port for http'
+ insertbefore: '^www.*80/tcp'
+ line: '# port for http by default'
+
+- name: Create file if it doesn't exist with a specific encoding
+ win_lineinfile:
+ path: C:\Temp\utf16.txt
+ create: yes
+ encoding: utf-16
+ line: This is a utf-16 encoded file
+
+- name: Add a line to a file and ensure the resulting file uses unix line separators
+ win_lineinfile:
+ path: C:\Temp\testfile.txt
+ line: Line added to file
+ newline: unix
+
+- name: Update a line using backrefs
+ win_lineinfile:
+ path: C:\Temp\example.conf
+ backrefs: yes
+ regex: '(^name=)'
+ line: '$1JohnDoe'
+'''
+
+RETURN = r'''
+backup:
+ description:
+ - Name of the backup file that was created.
+ - This is now deprecated, use C(backup_file) instead.
+ returned: if backup=yes
+ type: str
+ sample: C:\Path\To\File.txt.11540.20150212-220915.bak
+backup_file:
+ description: Name of the backup file that was created.
+ returned: if backup=yes
+ type: str
+ sample: C:\Path\To\File.txt.11540.20150212-220915.bak
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_path.ps1 b/test/support/windows-integration/plugins/modules/win_path.ps1
new file mode 100644
index 00000000..04eb41a3
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_path.ps1
@@ -0,0 +1,145 @@
+#!powershell
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+
+Set-StrictMode -Version 2
+$ErrorActionPreference = "Stop"
+
+$system_path = "System\CurrentControlSet\Control\Session Manager\Environment"
+$user_path = "Environment"
+
+# list/arraylist methods don't allow an IEqualityComparer override for case/backslash/quote-insensitivity, so roll our own search
+Function Get-IndexOfPathElement ($list, [string]$value) {
+ $idx = 0
+ $value = $value.Trim('"').Trim('\')
+ ForEach($el in $list) {
+ If ([string]$el.Trim('"').Trim('\') -ieq $value) {
+ return $idx
+ }
+
+ $idx++
+ }
+
+ return -1
+}
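+
+# Illustration of the search above: surrounding quotes and trailing
+# backslashes are trimmed and -ieq ignores case, so a hypothetical call
+# Get-IndexOfPathElement @('c:\tools') '"C:\Tools\"' returns 0.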
+
+# alters list in place, returns true if at least one element was added
+Function Add-Elements ($existing_elements, $elements_to_add) {
+ $last_idx = -1
+ $changed = $false
+
+ ForEach($el in $elements_to_add) {
+ $idx = Get-IndexOfPathElement $existing_elements $el
+
+ # add missing elements at the end
+ If ($idx -eq -1) {
+ $last_idx = $existing_elements.Add($el)
+ $changed = $true
+ }
+ ElseIf ($idx -lt $last_idx) {
+ $existing_elements.RemoveAt($idx) | Out-Null
+ $existing_elements.Add($el) | Out-Null
+ $last_idx = $existing_elements.Count - 1
+ $changed = $true
+ }
+ Else {
+ $last_idx = $idx
+ }
+ }
+
+ return $changed
+}
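+
+# Worked example of the reordering above: with existing elements A,B,C and
+# elements_to_add C,B, element C is found beyond the running $last_idx, then B
+# (now out of order) is removed and re-appended, yielding A,C,B - the requested
+# relative order is satisfied while untouched elements keep their positions.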
+
+# alters list in place, returns true if at least one element was removed
+Function Remove-Elements ($existing_elements, $elements_to_remove) {
+ $count = $existing_elements.Count
+
+ ForEach($el in $elements_to_remove) {
+ $idx = Get-IndexOfPathElement $existing_elements $el
+        $result.removed_idx = $idx  # note: only the most recently removed index is kept in the result
+ If ($idx -gt -1) {
+ $existing_elements.RemoveAt($idx)
+ }
+ }
+
+ return $count -ne $existing_elements.Count
+}
+
+# PS registry provider doesn't allow access to unexpanded REG_EXPAND_SZ; fall back to .NET
+Function Get-RawPathVar ($scope) {
+ If ($scope -eq "user") {
+ $env_key = [Microsoft.Win32.Registry]::CurrentUser.OpenSubKey($user_path)
+ }
+ ElseIf ($scope -eq "machine") {
+ $env_key = [Microsoft.Win32.Registry]::LocalMachine.OpenSubKey($system_path)
+ }
+
+ return $env_key.GetValue($var_name, "", [Microsoft.Win32.RegistryValueOptions]::DoNotExpandEnvironmentNames)
+}
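+
+# Note on the raw read above: DoNotExpandEnvironmentNames returns a stored
+# value such as %SystemRoot%\system32 verbatim instead of expanding it to
+# C:\Windows\system32, which is what keeps unexpanded REG_EXPAND_SZ elements
+# comparable and preserved on write.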
+
+Function Set-RawPathVar($path_value, $scope) {
+ If ($scope -eq "user") {
+ $var_path = "HKCU:\" + $user_path
+ }
+ ElseIf ($scope -eq "machine") {
+ $var_path = "HKLM:\" + $system_path
+ }
+
+ Set-ItemProperty $var_path -Name $var_name -Value $path_value -Type ExpandString | Out-Null
+
+ return $path_value
+}
+
+$parsed_args = Parse-Args $args -supports_check_mode $true
+
+$result = @{changed=$false}
+
+$var_name = Get-AnsibleParam $parsed_args "name" -Default "PATH"
+$elements = Get-AnsibleParam $parsed_args "elements" -FailIfEmpty $result
+$state = Get-AnsibleParam $parsed_args "state" -Default "present" -ValidateSet "present","absent"
+$scope = Get-AnsibleParam $parsed_args "scope" -Default "machine" -ValidateSet "machine","user"
+
+$check_mode = Get-AnsibleParam $parsed_args "_ansible_check_mode" -Default $false
+
+If ($elements -is [string]) {
+ $elements = @($elements)
+}
+
+If ($elements -isnot [Array]) {
+ Fail-Json $result "elements must be a string or list of path strings"
+}
+
+$current_value = Get-RawPathVar $scope
+$result.path_value = $current_value
+
+# TODO: test case-canonicalization on wacky unicode values (eg turkish i)
+# TODO: detect and warn/fail on unparseable path? (eg, unbalanced quotes, invalid path chars)
+# TODO: detect and warn/fail if system path and Powershell isn't on it?
+
+$existing_elements = New-Object System.Collections.ArrayList
+
+# split on semicolons, accounting for quoted values with embedded semicolons (which may or may not be wrapped in whitespace)
+$pathsplit_re = [regex] '((?<q>\s*"[^"]+"\s*)|(?<q>[^;]+))(;$|$|;)'
+
+ForEach ($m in $pathsplit_re.Matches($current_value)) {
+ $existing_elements.Add($m.Groups['q'].Value) | Out-Null
+}
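+
+# Sketch of the split above on a hypothetical value
+# 'C:\one;"C:\with;semi";C:\two': the quoted alternative in the regex captures
+# "C:\with;semi" as a single element, so three elements are produced rather
+# than four.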
+
+If ($state -eq "absent") {
+ $result.changed = Remove-Elements $existing_elements $elements
+}
+ElseIf ($state -eq "present") {
+ $result.changed = Add-Elements $existing_elements $elements
+}
+
+# calculate the new path value from the existing elements
+$path_value = [String]::Join(";", $existing_elements.ToArray())
+$result.path_value = $path_value
+
+If ($result.changed -and -not $check_mode) {
+ Set-RawPathVar $path_value $scope | Out-Null
+}
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_path.py b/test/support/windows-integration/plugins/modules/win_path.py
new file mode 100644
index 00000000..6404504f
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_path.py
@@ -0,0 +1,79 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2016, Red Hat | Ansible
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# This is a windows documentation stub. Actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_path
+version_added: "2.3"
+short_description: Manage Windows path environment variables
+description:
+ - Allows element-based ordering, addition, and removal of Windows path environment variables.
+options:
+ name:
+ description:
+ - Target path environment variable name.
+ type: str
+ default: PATH
+ elements:
+ description:
+    - A single path element, or a list of path elements (i.e., directories) to add or remove.
+ - When multiple elements are included in the list (and C(state) is C(present)), the elements are guaranteed to appear in the same relative order
+ in the resultant path value.
+    - Variable expansions (e.g., C(%VARNAME%)) are allowed, and are stored unexpanded in the target path element.
+ - Any existing path elements not mentioned in C(elements) are always preserved in their current order.
+ - New path elements are appended to the path, and existing path elements may be moved closer to the end to satisfy the requested ordering.
+ - Paths are compared in a case-insensitive fashion, and trailing backslashes are ignored for comparison purposes. However, note that trailing
+ backslashes in YAML require quotes.
+ type: list
+ required: yes
+ state:
+ description:
+ - Whether the path elements specified in C(elements) should be present or absent.
+ type: str
+ choices: [ absent, present ]
+ scope:
+ description:
+ - The level at which the environment variable specified by C(name) should be managed (either for the current user or global machine scope).
+ type: str
+ choices: [ machine, user ]
+ default: machine
+notes:
+ - This module is for modifying individual elements of path-like
+ environment variables. For general-purpose management of other
+ environment vars, use the M(win_environment) module.
+ - This module does not broadcast change events.
+ This means that the minority of windows applications which can have
+ their environment changed without restarting will not be notified and
+ therefore will need restarting to pick up new environment settings.
+ - User level environment variables will require an interactive user to
+ log out and in again before they become available.
+seealso:
+- module: win_environment
+author:
+- Matt Davis (@nitzmahone)
+'''
+
+EXAMPLES = r'''
+- name: Ensure that system32 and Powershell are present on the global system path, and in the specified order
+ win_path:
+ elements:
+ - '%SystemRoot%\system32'
+ - '%SystemRoot%\system32\WindowsPowerShell\v1.0'
+
+- name: Ensure that C:\Program Files\MyJavaThing is not on the current user's CLASSPATH
+ win_path:
+ name: CLASSPATH
+ elements: C:\Program Files\MyJavaThing
+ scope: user
+ state: absent
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_ping.ps1 b/test/support/windows-integration/plugins/modules/win_ping.ps1
new file mode 100644
index 00000000..c848b912
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_ping.ps1
@@ -0,0 +1,21 @@
+#!powershell
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+
+$spec = @{
+ options = @{
+ data = @{ type = "str"; default = "pong" }
+ }
+ supports_check_mode = $true
+}
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+$data = $module.Params.data
+
+if ($data -eq "crash") {
+ throw "boom"
+}
+
+$module.Result.ping = $data
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_ping.py b/test/support/windows-integration/plugins/modules/win_ping.py
new file mode 100644
index 00000000..6d35f379
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_ping.py
@@ -0,0 +1,55 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>, and others
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_ping
+version_added: "1.7"
+short_description: A windows version of the classic ping module
+description:
+ - Checks management connectivity of a windows host.
+ - This is NOT ICMP ping, this is just a trivial test module.
+ - For non-Windows targets, use the M(ping) module instead.
+ - For Network targets, use the M(net_ping) module instead.
+options:
+ data:
+ description:
+ - Alternate data to return instead of 'pong'.
+ - If this parameter is set to C(crash), the module will cause an exception.
+ type: str
+ default: pong
+seealso:
+- module: ping
+author:
+- Chris Church (@cchurch)
+'''
+
+EXAMPLES = r'''
+# Test connectivity to a windows host
+# ansible winserver -m win_ping
+
+- name: Example from an Ansible Playbook
+ win_ping:
+
+- name: Induce an exception to see what happens
+ win_ping:
+ data: crash
+'''
+
+RETURN = r'''
+ping:
+ description: Value provided with the data parameter.
+ returned: success
+ type: str
+ sample: pong
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_psexec.ps1 b/test/support/windows-integration/plugins/modules/win_psexec.ps1
new file mode 100644
index 00000000..04a51270
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_psexec.ps1
@@ -0,0 +1,152 @@
+#!powershell
+
+# Copyright: (c) 2017, Dag Wieers (@dagwieers) <dag@wieers.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+#Requires -Module Ansible.ModuleUtils.ArgvParser
+#Requires -Module Ansible.ModuleUtils.CommandUtil
+
+# See also: https://technet.microsoft.com/en-us/sysinternals/psexec.aspx
+
+$spec = @{
+ options = @{
+ command = @{ type='str'; required=$true }
+ executable = @{ type='path'; default='psexec.exe' }
+ hostnames = @{ type='list' }
+ username = @{ type='str' }
+ password = @{ type='str'; no_log=$true }
+ chdir = @{ type='path' }
+ wait = @{ type='bool'; default=$true }
+ nobanner = @{ type='bool'; default=$false }
+ noprofile = @{ type='bool'; default=$false }
+ elevated = @{ type='bool'; default=$false }
+ limited = @{ type='bool'; default=$false }
+ system = @{ type='bool'; default=$false }
+ interactive = @{ type='bool'; default=$false }
+ session = @{ type='int' }
+ priority = @{ type='str'; choices=@( 'background', 'low', 'belownormal', 'abovenormal', 'high', 'realtime' ) }
+ timeout = @{ type='int' }
+ }
+}
+
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+
+$command = $module.Params.command
+$executable = $module.Params.executable
+$hostnames = $module.Params.hostnames
+$username = $module.Params.username
+$password = $module.Params.password
+$chdir = $module.Params.chdir
+$wait = $module.Params.wait
+$nobanner = $module.Params.nobanner
+$noprofile = $module.Params.noprofile
+$elevated = $module.Params.elevated
+$limited = $module.Params.limited
+$system = $module.Params.system
+$interactive = $module.Params.interactive
+$session = $module.Params.session
+$priority = $module.Params.priority
+$timeout = $module.Params.timeout
+
+$module.Result.changed = $true
+
+If (-Not (Get-Command $executable -ErrorAction SilentlyContinue)) {
+ $module.FailJson("Executable '$executable' was not found.")
+}
+
+$arguments = [System.Collections.Generic.List`1[String]]@($executable)
+
+If ($nobanner -eq $true) {
+ $arguments.Add("-nobanner")
+}
+
+# Support running on local system if no hostname is specified
+If ($hostnames) {
+    $hostname_argument = ($hostnames | Sort-Object -Unique) -join ','
+ $arguments.Add("\\$hostname_argument")
+}
+
+# Username is optional
+If ($null -ne $username) {
+ $arguments.Add("-u")
+ $arguments.Add($username)
+}
+
+# Password is optional
+If ($null -ne $password) {
+ $arguments.Add("-p")
+ $arguments.Add($password)
+}
+
+If ($null -ne $chdir) {
+ $arguments.Add("-w")
+ $arguments.Add($chdir)
+}
+
+If ($wait -eq $false) {
+ $arguments.Add("-d")
+}
+
+If ($noprofile -eq $true) {
+ $arguments.Add("-e")
+}
+
+If ($elevated -eq $true) {
+ $arguments.Add("-h")
+}
+
+If ($system -eq $true) {
+ $arguments.Add("-s")
+}
+
+If ($interactive -eq $true) {
+ $arguments.Add("-i")
+ If ($null -ne $session) {
+ $arguments.Add($session)
+ }
+}
+
+If ($limited -eq $true) {
+ $arguments.Add("-l")
+}
+
+If ($null -ne $priority) {
+ $arguments.Add("-$priority")
+}
+
+If ($null -ne $timeout) {
+ $arguments.Add("-n")
+ $arguments.Add($timeout)
+}
+
+$arguments.Add("-accepteula")
+
+$argument_string = Argv-ToString -arguments $arguments
+
+# Add the command at the end of the argument string; we don't want to escape
+# it, as psexec doesn't expect it to be a single argument
+$argument_string += " $command"
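+
+# Hypothetical illustration: for command='whoami.exe' with nobanner=yes and a
+# single hostname, the string built above would resemble
+# 'psexec.exe -nobanner \\remote_server -accepteula whoami.exe'.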
+
+$start_datetime = [DateTime]::UtcNow
+$module.Result.psexec_command = $argument_string
+
+$command_result = Run-Command -command $argument_string
+
+$end_datetime = [DateTime]::UtcNow
+
+$module.Result.stdout = $command_result.stdout
+$module.Result.stderr = $command_result.stderr
+
+If ($wait -eq $true) {
+ $module.Result.rc = $command_result.rc
+} else {
+ $module.Result.rc = 0
+ $module.Result.pid = $command_result.rc
+}
+
+$module.Result.start = $start_datetime.ToString("yyyy-MM-dd HH:mm:ss.ffffff")
+$module.Result.end = $end_datetime.ToString("yyyy-MM-dd HH:mm:ss.ffffff")
+$module.Result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff")
+
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_psexec.py b/test/support/windows-integration/plugins/modules/win_psexec.py
new file mode 100644
index 00000000..c3fc37e4
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_psexec.py
@@ -0,0 +1,172 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: 2017, Dag Wieers (@dagwieers) <dag@wieers.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_psexec
+version_added: '2.3'
+short_description: Runs commands (remotely) as another (privileged) user
+description:
+- Run commands (remotely) through the PsExec service.
+- Run commands as another (domain) user (with elevated privileges).
+requirements:
+- Microsoft PsExec
+options:
+ command:
+ description:
+ - The command line to run through PsExec (limited to 260 characters).
+ type: str
+ required: yes
+ executable:
+ description:
+ - The location of the PsExec utility (in case it is not located in your PATH).
+ type: path
+ default: psexec.exe
+ hostnames:
+ description:
+    - The hostnames to run the command on.
+ - If not provided, the command is run locally.
+ type: list
+ username:
+ description:
+ - The (remote) user to run the command as.
+ - If not provided, the current user is used.
+ type: str
+ password:
+ description:
+ - The password for the (remote) user to run the command as.
+    - This is mandatory in order to authenticate yourself.
+ type: str
+ chdir:
+ description:
+ - Run the command from this (remote) directory.
+ type: path
+ nobanner:
+ description:
+ - Do not display the startup banner and copyright message.
+ - This only works for specific versions of the PsExec binary.
+ type: bool
+ default: no
+ version_added: '2.4'
+ noprofile:
+ description:
+ - Run the command without loading the account's profile.
+ type: bool
+ default: no
+ elevated:
+ description:
+ - Run the command with elevated privileges.
+ type: bool
+ default: no
+ interactive:
+ description:
+ - Run the program so that it interacts with the desktop on the remote system.
+ type: bool
+ default: no
+ session:
+ description:
+ - Specifies the session ID to use.
+ - This parameter works in conjunction with I(interactive).
+ - It has no effect when I(interactive) is set to C(no).
+ type: int
+ version_added: '2.7'
+ limited:
+ description:
+    - Run the command as a limited user (strips the Administrators group and allows only privileges assigned to the Users group).
+ type: bool
+ default: no
+ system:
+ description:
+ - Run the remote command in the System account.
+ type: bool
+ default: no
+ priority:
+ description:
+ - Used to run the command at a different priority.
+ choices: [ abovenormal, background, belownormal, high, low, realtime ]
+ timeout:
+ description:
+    - The connection timeout in seconds.
+ type: int
+ wait:
+ description:
+ - Wait for the application to terminate.
+ - Only use for non-interactive applications.
+ type: bool
+ default: yes
+notes:
+- More information related to Microsoft PsExec is available from
+ U(https://technet.microsoft.com/en-us/sysinternals/bb897553.aspx)
+seealso:
+- module: psexec
+- module: raw
+- module: win_command
+- module: win_shell
+author:
+- Dag Wieers (@dagwieers)
+'''
+
+EXAMPLES = r'''
+- name: Test the PsExec connection to the local system (target node) with your user
+ win_psexec:
+ command: whoami.exe
+
+- name: Run regedit.exe locally (on target node) as SYSTEM and interactively
+ win_psexec:
+ command: regedit.exe
+ interactive: yes
+ system: yes
+
+- name: Run the setup.exe installer on multiple servers using the Domain Administrator
+ win_psexec:
+ command: E:\setup.exe /i /IACCEPTEULA
+ hostnames:
+ - remote_server1
+ - remote_server2
+ username: DOMAIN\Administrator
+ password: some_password
+ priority: high
+
+- name: Run PsExec from custom location C:\Program Files\sysinternals\
+ win_psexec:
+ command: netsh advfirewall set allprofiles state off
+ executable: C:\Program Files\sysinternals\psexec.exe
+ hostnames: [ remote_server ]
+ password: some_password
+ priority: low
+'''
+
+RETURN = r'''
+psexec_command:
+ description: The complete command line used by the module, including PsExec call and additional options.
+ returned: always
+ type: str
+ sample: psexec.exe -nobanner \\remote_server -u "DOMAIN\Administrator" -p "some_password" -accepteula E:\setup.exe
+pid:
+ description: The PID of the async process created by PsExec.
+ returned: when C(wait=False)
+ type: int
+ sample: 1532
+rc:
+ description: The return code for the command.
+ returned: always
+ type: int
+ sample: 0
+stdout:
+ description: The standard output from the command.
+ returned: always
+ type: str
+ sample: Success.
+stderr:
+ description: The error output from the command.
+ returned: always
+ type: str
+ sample: Error 15 running E:\setup.exe
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_reboot.py b/test/support/windows-integration/plugins/modules/win_reboot.py
new file mode 100644
index 00000000..14318041
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_reboot.py
@@ -0,0 +1,131 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_reboot
+short_description: Reboot a windows machine
+description:
+- Reboot a Windows machine, wait for it to go down, come back up, and respond to commands.
+- For non-Windows targets, use the M(reboot) module instead.
+version_added: '2.1'
+options:
+ pre_reboot_delay:
+ description:
+ - Seconds to wait before reboot. Passed as a parameter to the reboot command.
+ type: int
+ default: 2
+ aliases: [ pre_reboot_delay_sec ]
+ post_reboot_delay:
+ description:
+ - Seconds to wait after the reboot command was successful before attempting to validate the system rebooted successfully.
+    - This is useful if you want to wait for something to settle despite your connection already working.
+ type: int
+ default: 0
+ version_added: '2.4'
+ aliases: [ post_reboot_delay_sec ]
+ shutdown_timeout:
+ description:
+ - Maximum seconds to wait for shutdown to occur.
+ - Increase this timeout for very slow hardware, large update applications, etc.
+ - This option has been removed since Ansible 2.5 as the win_reboot behavior has changed.
+ type: int
+ default: 600
+ aliases: [ shutdown_timeout_sec ]
+ reboot_timeout:
+ description:
+ - Maximum seconds to wait for machine to re-appear on the network and respond to a test command.
+ - This timeout is evaluated separately for both reboot verification and test command success so maximum clock time is actually twice this value.
+ type: int
+ default: 600
+ aliases: [ reboot_timeout_sec ]
+ connect_timeout:
+ description:
+ - Maximum seconds to wait for a single successful TCP connection to the WinRM endpoint before trying again.
+ type: int
+ default: 5
+ aliases: [ connect_timeout_sec ]
+ test_command:
+ description:
+    - Command to run that must succeed to determine the machine is ready for management.
+ type: str
+ default: whoami
+ msg:
+ description:
+ - Message to display to users.
+ type: str
+ default: Reboot initiated by Ansible
+ boot_time_command:
+ description:
+ - Command to run that returns a unique string indicating the last time the system was booted.
+ - Setting this to a command that has different output each time it is run will cause the task to fail.
+ type: str
+ default: '(Get-WmiObject -ClassName Win32_OperatingSystem).LastBootUpTime'
+ version_added: '2.10'
+notes:
+- If a shutdown was already scheduled on the system, C(win_reboot) will abort the scheduled shutdown and enforce its own shutdown.
+- Beware that when C(win_reboot) returns, the Windows system may not have settled yet and some base services could be in limbo.
+ This can result in unexpected behavior. Check the examples for ways to mitigate this.
+- The connection user must have the C(SeRemoteShutdownPrivilege) privilege enabled, see
+ U(https://docs.microsoft.com/en-us/windows/security/threat-protection/security-policy-settings/force-shutdown-from-a-remote-system)
+ for more information.
+seealso:
+- module: reboot
+author:
+- Matt Davis (@nitzmahone)
+'''
+
+EXAMPLES = r'''
+- name: Reboot the machine with all defaults
+ win_reboot:
+
+- name: Reboot a slow machine that might have lots of updates to apply
+ win_reboot:
+ reboot_timeout: 3600
+
+# Install a Windows feature and reboot if necessary
+- name: Install IIS Web-Server
+ win_feature:
+ name: Web-Server
+ register: iis_install
+
+- name: Reboot when Web-Server feature requires it
+ win_reboot:
+ when: iis_install.reboot_required
+
+# One way to ensure the system is reliable is to set WinRM to a delayed startup
+- name: Ensure WinRM starts when the system has settled and is ready to work reliably
+ win_service:
+ name: WinRM
+ start_mode: delayed
+
+
+# Additionally, you can add a delay before running the next task
+- name: Reboot a machine that takes time to settle after being booted
+ win_reboot:
+ post_reboot_delay: 120
+
+# Or you can make win_reboot validate exactly what you need to work before running the next task
+- name: Validate that the netlogon service has started, before running the next task
+ win_reboot:
+ test_command: 'exit (Get-Service -Name Netlogon).Status -ne "Running"'
+'''
+
+RETURN = r'''
+rebooted:
+ description: True if the machine was rebooted.
+ returned: always
+ type: bool
+ sample: true
+elapsed:
+ description: The number of seconds that elapsed waiting for the system to be rebooted.
+ returned: always
+ type: float
+ sample: 23.2
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_regedit.ps1 b/test/support/windows-integration/plugins/modules/win_regedit.ps1
new file mode 100644
index 00000000..c56b4833
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_regedit.ps1
@@ -0,0 +1,495 @@
+#!powershell
+
+# Copyright: (c) 2015, Adam Keech <akeech@chathamfinancial.com>
+# Copyright: (c) 2015, Josh Ludwig <jludwig@chathamfinancial.com>
+# Copyright: (c) 2017, Jordan Borean <jborean93@gmail.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.PrivilegeUtil
+
+$params = Parse-Args -arguments $args -supports_check_mode $true
+$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false
+$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false
+$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP
+
+$path = Get-AnsibleParam -obj $params -name "path" -type "str" -failifempty $true -aliases "key"
+$name = Get-AnsibleParam -obj $params -name "name" -type "str" -aliases "entry","value"
+$data = Get-AnsibleParam -obj $params -name "data"
+$type = Get-AnsibleParam -obj $params -name "type" -type "str" -default "string" -validateset "none","binary","dword","expandstring","multistring","string","qword" -aliases "datatype"
+$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent"
+$delete_key = Get-AnsibleParam -obj $params -name "delete_key" -type "bool" -default $true
+$hive = Get-AnsibleParam -obj $params -name "hive" -type "path"
+
+$result = @{
+ changed = $false
+ data_changed = $false
+ data_type_changed = $false
+}
+
+if ($diff_mode) {
+ $result.diff = @{
+ before = ""
+ after = ""
+ }
+}
+
+$registry_util = @'
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+
+namespace Ansible.WinRegedit
+{
+ internal class NativeMethods
+ {
+ [DllImport("advapi32.dll", CharSet = CharSet.Unicode)]
+ public static extern int RegLoadKeyW(
+ UInt32 hKey,
+ string lpSubKey,
+ string lpFile);
+
+ [DllImport("advapi32.dll", CharSet = CharSet.Unicode)]
+ public static extern int RegUnLoadKeyW(
+ UInt32 hKey,
+ string lpSubKey);
+ }
+
+ public class Win32Exception : System.ComponentModel.Win32Exception
+ {
+ private string _msg;
+ public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { }
+ public Win32Exception(int errorCode, string message) : base(errorCode)
+ {
+ _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode);
+ }
+ public override string Message { get { return _msg; } }
+ public static explicit operator Win32Exception(string message) { return new Win32Exception(message); }
+ }
+
+ public class Hive : IDisposable
+ {
+ private const UInt32 SCOPE = 0x80000002; // HKLM
+ private string hiveKey;
+ private bool loaded = false;
+
+ public Hive(string hiveKey, string hivePath)
+ {
+ this.hiveKey = hiveKey;
+ int ret = NativeMethods.RegLoadKeyW(SCOPE, hiveKey, hivePath);
+ if (ret != 0)
+ throw new Win32Exception(ret, String.Format("Failed to load registry hive at {0}", hivePath));
+ loaded = true;
+ }
+
+ public static void UnloadHive(string hiveKey)
+ {
+ int ret = NativeMethods.RegUnLoadKeyW(SCOPE, hiveKey);
+ if (ret != 0)
+ throw new Win32Exception(ret, String.Format("Failed to unload registry hive at {0}", hiveKey));
+ }
+
+ public void Dispose()
+ {
+ if (loaded)
+ {
+ // Make sure the garbage collector disposes all unused handles and waits until it is complete
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+
+ UnloadHive(hiveKey);
+ loaded = false;
+ }
+ GC.SuppressFinalize(this);
+ }
+ ~Hive() { this.Dispose(); }
+ }
+}
+'@
+
+# fire a warning if the property name isn't specified; the (Default) key ($null) can only be a string
+if ($null -eq $name -and $type -ne "string") {
+ Add-Warning -obj $result -message "the data type when name is not specified can only be 'string', the type has automatically been converted"
+ $type = "string"
+}
+
+# Check that the registry path is in PSDrive format: HKCC, HKCR, HKCU, HKLM, HKU
+if ($path -notmatch "^HK(CC|CR|CU|LM|U):\\") {
+ Fail-Json $result "path: $path is not a valid powershell path, see module documentation for examples."
+}
+
+# Add a warning if the path does not contain a '\' and is not the leaf path
+$registry_path = (Split-Path -Path $path -NoQualifier).Substring(1) # removes the hive: and leading \
+$registry_leaf = Split-Path -Path $path -Leaf
+if ($registry_path -ne $registry_leaf -and -not $registry_path.Contains('\')) {
+ $msg = "path is not using '\' as a separator, support for '/' as a separator will be removed in a future Ansible version"
+ Add-DeprecationWarning -obj $result -message $msg -version 2.12
+ $registry_path = $registry_path.Replace('/', '\')
+}
+
+# Simplified version of Convert-HexStringToByteArray from
+# https://cyber-defense.sans.org/blog/2010/02/11/powershell-byte-array-hex-convert
+# Expects a hex in the format you get when you run reg.exe export,
+# and converts to a byte array so powershell can modify binary registry entries
+# import format is like 'hex:be,ef,be,ef,be,ef,be,ef,be,ef'
+Function Convert-RegExportHexStringToByteArray($string) {
+ # Remove 'hex:' from the front of the string if present
+ $string = $string.ToLower() -replace '^hex\:',''
+
+ # Remove whitespace and any other non-hex crud.
+ $string = $string -replace '[^a-f0-9\\,x\-\:]',''
+
+ # Turn commas into colons
+ $string = $string -replace ',',':'
+
+ # Maybe there's nothing left over to convert...
+ if ($string.Length -eq 0) {
+ return ,@()
+ }
+
+ # Split string with or without colon delimiters.
+ if ($string.Length -eq 1) {
+ return ,@([System.Convert]::ToByte($string,16))
+ } elseif (($string.Length % 2 -eq 0) -and ($string.IndexOf(":") -eq -1)) {
+ return ,@($string -split '([a-f0-9]{2})' | foreach-object { if ($_) {[System.Convert]::ToByte($_,16)}})
+ } elseif ($string.IndexOf(":") -ne -1) {
+ return ,@($string -split ':+' | foreach-object {[System.Convert]::ToByte($_,16)})
+ } else {
+ return ,@()
+ }
+}
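+
+# Hedged usage sketch of the converter above:
+# Convert-RegExportHexStringToByteArray -string 'hex:be,ef' yields the bytes
+# 0xbe,0xef, while a bare single hex digit such as 'a' yields one byte (0x0a).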
+
+Function Compare-RegistryProperties($existing, $new) {
+ # Outputs $true if the property values don't match
+ if ($existing -is [Array]) {
+ (Compare-Object -ReferenceObject $existing -DifferenceObject $new -SyncWindow 0).Length -ne 0
+ } else {
+ $existing -cne $new
+ }
+}
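+
+# Note on the comparison above: -SyncWindow 0 makes the array comparison
+# order-sensitive (e.g. @('a','b') vs @('b','a') counts as changed) and -cne
+# makes scalar comparison case-sensitive, matching the case-sensitive
+# behaviour documented for Ansible 2.4+.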
+
+Function Get-DiffValue {
+ param(
+ [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryValueKind]$Type,
+ [Parameter(Mandatory=$true)][Object]$Value
+ )
+
+ $diff = @{ type = $Type.ToString(); value = $Value }
+
+ $enum = [Microsoft.Win32.RegistryValueKind]
+ if ($Type -in @($enum::Binary, $enum::None)) {
+ $diff.value = [System.Collections.Generic.List`1[String]]@()
+ foreach ($dec_value in $Value) {
+ $diff.value.Add("0x{0:x2}" -f $dec_value)
+ }
+ } elseif ($Type -eq $enum::DWord) {
+ $diff.value = "0x{0:x8}" -f $Value
+ } elseif ($Type -eq $enum::QWord) {
+ $diff.value = "0x{0:x16}" -f $Value
+ }
+
+ return $diff
+}
+
+Function Set-StateAbsent {
+ param(
+ # Used for diffs and exception messages to match up against Ansible input
+ [Parameter(Mandatory=$true)][String]$PrintPath,
+ [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryKey]$Hive,
+ [Parameter(Mandatory=$true)][String]$Path,
+ [String]$Name,
+ [Switch]$DeleteKey
+ )
+
+ $key = $Hive.OpenSubKey($Path, $true)
+ if ($null -eq $key) {
+ # Key does not exist, no need to delete anything
+ return
+ }
+
+ try {
+ if ($DeleteKey -and -not $Name) {
+ # delete_key=yes is set and name is null/empty, so delete the entire key
+ $key.Dispose()
+ $key = $null
+ if (-not $check_mode) {
+ try {
+ $Hive.DeleteSubKeyTree($Path, $false)
+ } catch {
+ Fail-Json -obj $result -message "failed to delete registry key at $($PrintPath): $($_.Exception.Message)"
+ }
+ }
+ $result.changed = $true
+
+ if ($diff_mode) {
+ $result.diff.before = @{$PrintPath = @{}}
+ $result.diff.after = @{}
+ }
+ } else {
+ # delete_key=no or name is not null/empty, delete the property not the full key
+ $property = $key.GetValue($Name)
+ if ($null -eq $property) {
+ # property does not exist
+ return
+ }
+ $property_type = $key.GetValueKind($Name) # used for the diff
+
+ if (-not $check_mode) {
+ try {
+ $key.DeleteValue($Name)
+ } catch {
+ Fail-Json -obj $result -message "failed to delete registry property '$Name' at $($PrintPath): $($_.Exception.Message)"
+ }
+ }
+
+ $result.changed = $true
+ if ($diff_mode) {
+ $diff_value = Get-DiffValue -Type $property_type -Value $property
+ $result.diff.before = @{ $PrintPath = @{ $Name = $diff_value } }
+ $result.diff.after = @{ $PrintPath = @{} }
+ }
+ }
+ } finally {
+ if ($key) {
+ $key.Dispose()
+ }
+ }
+}
+
+Function Set-StatePresent {
+ param(
+ [Parameter(Mandatory=$true)][String]$PrintPath,
+ [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryKey]$Hive,
+ [Parameter(Mandatory=$true)][String]$Path,
+ [String]$Name,
+ [Object]$Data,
+ [Microsoft.Win32.RegistryValueKind]$Type
+ )
+
+ $key = $Hive.OpenSubKey($Path, $true)
+ try {
+ if ($null -eq $key) {
+ # the key does not exist, create it so the next steps work
+ if (-not $check_mode) {
+ try {
+ $key = $Hive.CreateSubKey($Path)
+ } catch {
+ Fail-Json -obj $result -message "failed to create registry key at $($PrintPath): $($_.Exception.Message)"
+ }
+ }
+ $result.changed = $true
+
+ if ($diff_mode) {
+ $result.diff.before = @{}
+ $result.diff.after = @{$PrintPath = @{}}
+ }
+ } elseif ($diff_mode) {
+ # Make sure the diff is in an expected state for the key
+ $result.diff.before = @{$PrintPath = @{}}
+ $result.diff.after = @{$PrintPath = @{}}
+ }
+
+ if ($null -eq $key -or $null -eq $Data) {
+ # Check mode and key was created above, we cannot do any more work, or $Data is $null which happens when
+ # we create a new key but haven't explicitly set the data
+ return
+ }
+
+ $property = $key.GetValue($Name, $null, [Microsoft.Win32.RegistryValueOptions]::DoNotExpandEnvironmentNames)
+ if ($null -ne $property) {
+ # property exists, need to compare the values and type
+ $existing_type = $key.GetValueKind($name)
+ $change_value = $false
+
+ if ($Type -ne $existing_type) {
+ $change_value = $true
+ $result.data_type_changed = $true
+ $data_mismatch = Compare-RegistryProperties -existing $property -new $Data
+ if ($data_mismatch) {
+ $result.data_changed = $true
+ }
+ } else {
+ $data_mismatch = Compare-RegistryProperties -existing $property -new $Data
+ if ($data_mismatch) {
+ $change_value = $true
+ $result.data_changed = $true
+ }
+ }
+
+ if ($change_value) {
+ if (-not $check_mode) {
+ try {
+ $key.SetValue($Name, $Data, $Type)
+ } catch {
+ Fail-Json -obj $result -message "failed to change registry property '$Name' at $($PrintPath): $($_.Exception.Message)"
+ }
+ }
+ $result.changed = $true
+
+ if ($diff_mode) {
+ $result.diff.before.$PrintPath.$Name = Get-DiffValue -Type $existing_type -Value $property
+ $result.diff.after.$PrintPath.$Name = Get-DiffValue -Type $Type -Value $Data
+ }
+ } elseif ($diff_mode) {
+ $diff_value = Get-DiffValue -Type $existing_type -Value $property
+ $result.diff.before.$PrintPath.$Name = $diff_value
+ $result.diff.after.$PrintPath.$Name = $diff_value
+ }
+ } else {
+ # property doesn't exist just create a new one
+ if (-not $check_mode) {
+ try {
+ $key.SetValue($Name, $Data, $Type)
+ } catch {
+ Fail-Json -obj $result -message "failed to create registry property '$Name' at $($PrintPath): $($_.Exception.Message)"
+ }
+ }
+ $result.changed = $true
+
+ if ($diff_mode) {
+ $result.diff.after.$PrintPath.$Name = Get-DiffValue -Type $Type -Value $Data
+ }
+ }
+ } finally {
+ if ($key) {
+ $key.Dispose()
+ }
+ }
+}
+
+# convert property names "" to $null as "" refers to (Default)
+if ($name -eq "") {
+ $name = $null
+}
+
+# convert the data to the required format
+if ($type -in @("binary", "none")) {
+ if ($null -eq $data) {
+ $data = ""
+ }
+
+ # convert the data from string to byte array if in hex: format
+ if ($data -is [String]) {
+ $data = [byte[]](Convert-RegExportHexStringToByteArray -string $data)
+ } elseif ($data -is [Int]) {
+ if ($data -gt 255) {
+ Fail-Json $result "cannot convert binary data '$data' to byte array, please specify this value as a yaml byte array or a comma separated hex value string"
+ }
+ $data = [byte[]]@([byte]$data)
+ } elseif ($data -is [Array]) {
+ $data = [byte[]]$data
+ }
+} elseif ($type -in @("dword", "qword")) {
+    # dword and qword values don't allow null, so default to 0
+ if ($null -eq $data) {
+ $data = 0
+ }
+
+ if ($data -is [String]) {
+        # If the data is a string, convert it to an unsigned int64. It needs
+        # to be unsigned because Ansible passes in an unsigned value while
+        # PowerShell uses a signed data type. The value is then converted to
+        # the proper signed type below.
+ $data = [UInt64]$data
+ }
+
+ if ($type -eq "dword") {
+ if ($data -gt [UInt32]::MaxValue) {
+ Fail-Json $result "data cannot be larger than 0xffffffff when type is dword"
+ } elseif ($data -gt [Int32]::MaxValue) {
+ # when dealing with larger int32 (> 2147483647 or 0x7FFFFFFF) powershell
+ # automatically converts it to a signed int64. We need to convert this to
+ # signed int32 by parsing the hex string value.
+ $data = "0x$("{0:x}" -f $data)"
+ }
+ $data = [Int32]$data
+ } else {
+ if ($data -gt [UInt64]::MaxValue) {
+ Fail-Json $result "data cannot be larger than 0xffffffffffffffff when type is qword"
+ } elseif ($data -gt [Int64]::MaxValue) {
+ $data = "0x$("{0:x}" -f $data)"
+ }
+ $data = [Int64]$data
+ }
+} elseif ($type -in @("string", "expandstring") -and $name) {
+ # a null string or expandstring must be empty quotes
+ # Only do this if $name has been defined (not the default key)
+ if ($null -eq $data) {
+ $data = ""
+ }
+} elseif ($type -eq "multistring") {
+ # convert the data for a multistring to a String[] array
+ if ($null -eq $data) {
+ $data = [String[]]@()
+ } elseif ($data -isnot [Array]) {
+ $new_data = New-Object -TypeName String[] -ArgumentList 1
+ $new_data[0] = $data.ToString([CultureInfo]::InvariantCulture)
+ $data = $new_data
+ } else {
+ $new_data = New-Object -TypeName String[] -ArgumentList $data.Count
+ foreach ($entry in $data) {
+ $new_data[$data.IndexOf($entry)] = $entry.ToString([CultureInfo]::InvariantCulture)
+ }
+ $data = $new_data
+ }
+}
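+
+# Worked example, assuming type=dword and data='0xff2500ae': the string parses
+# to UInt64 4280615086, which exceeds [Int32]::MaxValue, so it is re-parsed
+# through its hex form into the negative Int32 with the same 32-bit pattern
+# before being written to the registry.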
+
+# convert the type string to the .NET class
+$type = [System.Enum]::Parse([Microsoft.Win32.RegistryValueKind], $type, $true)
+
+$registry_hive = switch(Split-Path -Path $path -Qualifier) {
+ "HKCR:" { [Microsoft.Win32.Registry]::ClassesRoot }
+ "HKCC:" { [Microsoft.Win32.Registry]::CurrentConfig }
+ "HKCU:" { [Microsoft.Win32.Registry]::CurrentUser }
+ "HKLM:" { [Microsoft.Win32.Registry]::LocalMachine }
+ "HKU:" { [Microsoft.Win32.Registry]::Users }
+}
+$loaded_hive = $null
+try {
+ if ($hive) {
+ if (-not (Test-Path -LiteralPath $hive)) {
+ Fail-Json -obj $result -message "hive at path '$hive' is not valid or accessible, cannot load hive"
+ }
+
+ $original_tmp = $env:TMP
+ $env:TMP = $_remote_tmp
+ Add-Type -TypeDefinition $registry_util
+ $env:TMP = $original_tmp
+
+ try {
+ Set-AnsiblePrivilege -Name SeBackupPrivilege -Value $true
+ Set-AnsiblePrivilege -Name SeRestorePrivilege -Value $true
+ } catch [System.ComponentModel.Win32Exception] {
+ Fail-Json -obj $result -message "failed to enable SeBackupPrivilege and SeRestorePrivilege for the current process: $($_.Exception.Message)"
+ }
+
+ if (Test-Path -Path HKLM:\ANSIBLE) {
+ Add-Warning -obj $result -message "hive already loaded at HKLM:\ANSIBLE, had to unload hive for win_regedit to continue"
+ try {
+ [Ansible.WinRegedit.Hive]::UnloadHive("ANSIBLE")
+ } catch [System.ComponentModel.Win32Exception] {
+ Fail-Json -obj $result -message "failed to unload registry hive HKLM:\ANSIBLE from $($hive): $($_.Exception.Message)"
+ }
+ }
+
+ try {
+ $loaded_hive = New-Object -TypeName Ansible.WinRegedit.Hive -ArgumentList "ANSIBLE", $hive
+ } catch [System.ComponentModel.Win32Exception] {
+ Fail-Json -obj $result -message "failed to load registry hive from '$hive' to HKLM:\ANSIBLE: $($_.Exception.Message)"
+ }
+ }
+
+ if ($state -eq "present") {
+ Set-StatePresent -PrintPath $path -Hive $registry_hive -Path $registry_path -Name $name -Data $data -Type $type
+ } else {
+ Set-StateAbsent -PrintPath $path -Hive $registry_hive -Path $registry_path -Name $name -DeleteKey:$delete_key
+ }
+} finally {
+ $registry_hive.Dispose()
+ if ($loaded_hive) {
+ $loaded_hive.Dispose()
+ }
+}
+
+Exit-Json $result
+
diff --git a/test/support/windows-integration/plugins/modules/win_regedit.py b/test/support/windows-integration/plugins/modules/win_regedit.py
new file mode 100644
index 00000000..2c0fff71
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_regedit.py
@@ -0,0 +1,210 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2015, Adam Keech <akeech@chathamfinancial.com>
+# Copyright: (c) 2015, Josh Ludwig <jludwig@chathamfinancial.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'core'}
+
+
+DOCUMENTATION = r'''
+---
+module: win_regedit
+version_added: '2.0'
+short_description: Add, change, or remove registry keys and values
+description:
+- Add, modify or remove registry keys and values.
+- More information about the windows registry from Wikipedia
+ U(https://en.wikipedia.org/wiki/Windows_Registry).
+options:
+ path:
+ description:
+ - Name of the registry path.
+ - 'Should be in one of the following registry hives: HKCC, HKCR, HKCU,
+ HKLM, HKU.'
+ type: str
+ required: yes
+ aliases: [ key ]
+ name:
+ description:
+ - Name of the registry entry in the above C(path) parameters.
+    - If not provided or empty, the '(Default)' property for the key will
+      be used.
+ type: str
+ aliases: [ entry, value ]
+ data:
+ description:
+ - Value of the registry entry C(name) in C(path).
+ - If not specified then the value for the property will be null for the
+ corresponding C(type).
+ - Binary and None data should be expressed in a yaml byte array or as comma
+ separated hex values.
+ - An easy way to generate this is to run C(regedit.exe) and use the
+ I(export) option to save the registry values to a file.
+ - In the exported file, binary value will look like C(hex:be,ef,be,ef), the
+ C(hex:) prefix is optional.
+ - DWORD and QWORD values should either be represented as a decimal number
+ or a hex value.
+ - Multistring values should be passed in as a list.
+ - See the examples for more details on how to format this data.
+ type: str
+ type:
+ description:
+ - The registry value data type.
+ type: str
+ choices: [ binary, dword, expandstring, multistring, string, qword ]
+ default: string
+ aliases: [ datatype ]
+ state:
+ description:
+ - The state of the registry entry.
+ type: str
+ choices: [ absent, present ]
+ default: present
+ delete_key:
+ description:
+    - When C(state) is C(absent), this will delete the entire key.
+    - If C(no), it will only clear out the '(Default)' property for
+      that key.
+ type: bool
+ default: yes
+ version_added: '2.4'
+ hive:
+ description:
+ - A path to a hive key like C:\Users\Default\NTUSER.DAT to load in the
+ registry.
+    - This hive is loaded under the HKLM:\ANSIBLE key, which can then be used
+      in I(path) like any other path.
+ - This can be used to load the default user profile registry hive or any
+ other hive saved as a file.
+ - Using this function requires the user to have the C(SeRestorePrivilege)
+ and C(SeBackupPrivilege) privileges enabled.
+ type: path
+ version_added: '2.5'
+notes:
+- Check-mode C(-C/--check) and diff output C(-D/--diff) are supported, so that you can test every change against the active configuration before
+ applying changes.
+- Beware that some registry hives (C(HKEY_USERS) in particular) do not allow creating new registry paths in the root folder.
+- Since Ansible 2.4, when checking if a string registry value has changed, a case-sensitive test is used. Previously the test was case-insensitive.
+seealso:
+- module: win_reg_stat
+- module: win_regmerge
+author:
+- Adam Keech (@smadam813)
+- Josh Ludwig (@joshludwig)
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+- name: Create registry path MyCompany
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+
+- name: Add or update registry path MyCompany, with entry 'hello', and containing 'world'
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ data: world
+
+- name: Add or update registry path MyCompany, with dword entry 'hello', and containing 1337 as the decimal value
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ data: 1337
+ type: dword
+
+- name: Add or update registry path MyCompany, with dword entry 'hello', and containing 0xff2500ae as the hex value
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ data: 0xff2500ae
+ type: dword
+
+- name: Add or update registry path MyCompany, with binary entry 'hello', and containing binary data in hex-string format
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ data: hex:be,ef,be,ef,be,ef,be,ef,be,ef
+ type: binary
+
+- name: Add or update registry path MyCompany, with binary entry 'hello', and containing binary data in yaml format
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ data: [0xbe,0xef,0xbe,0xef,0xbe,0xef,0xbe,0xef,0xbe,0xef]
+ type: binary
+
+- name: Add or update registry path MyCompany, with expand string entry 'hello'
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ data: '%appdata%\local'
+ type: expandstring
+
+- name: Add or update registry path MyCompany, with multi string entry 'hello'
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ data: ['hello', 'world']
+ type: multistring
+
+- name: Disable keyboard layout hotkey for all users (changes existing)
+ win_regedit:
+ path: HKU:\.DEFAULT\Keyboard Layout\Toggle
+ name: Layout Hotkey
+ data: 3
+ type: dword
+
+- name: Disable language hotkey for the current user (adds new)
+ win_regedit:
+ path: HKCU:\Keyboard Layout\Toggle
+ name: Language Hotkey
+ data: 3
+ type: dword
+
+- name: Remove registry path MyCompany (including all entries it contains)
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ state: absent
+ delete_key: yes
+
+- name: Clear the existing (Default) entry at path MyCompany
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ state: absent
+ delete_key: no
+
+- name: Remove entry 'hello' from registry path MyCompany
+ win_regedit:
+ path: HKCU:\Software\MyCompany
+ name: hello
+ state: absent
+
+- name: Change default mouse trailing settings for new users
+ win_regedit:
+ path: HKLM:\ANSIBLE\Control Panel\Mouse
+ name: MouseTrails
+ data: 10
+ type: string
+ state: present
+ hive: C:\Users\Default\NTUSER.dat
+'''
+
+RETURN = r'''
+data_changed:
+ description: Whether this invocation changed the data in the registry value.
+ returned: success
+ type: bool
+ sample: false
+data_type_changed:
+ description: Whether this invocation changed the datatype of the registry value.
+ returned: success
+ type: bool
+ sample: true
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_security_policy.ps1 b/test/support/windows-integration/plugins/modules/win_security_policy.ps1
new file mode 100644
index 00000000..274204b6
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_security_policy.ps1
@@ -0,0 +1,196 @@
+#!powershell
+
+# Copyright: (c) 2017, Jordan Borean <jborean93@gmail.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+
+$ErrorActionPreference = 'Stop'
+
+$params = Parse-Args $args -supports_check_mode $true
+$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false
+$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false
+
+$section = Get-AnsibleParam -obj $params -name "section" -type "str" -failifempty $true
+$key = Get-AnsibleParam -obj $params -name "key" -type "str" -failifempty $true
+$value = Get-AnsibleParam -obj $params -name "value" -failifempty $true
+
+$result = @{
+ changed = $false
+ section = $section
+ key = $key
+ value = $value
+}
+
+if ($diff_mode) {
+ $result.diff = @{}
+}
+
+Function Run-SecEdit($arguments) {
+ $stdout = $null
+ $stderr = $null
+ $log_path = [IO.Path]::GetTempFileName()
+ $arguments = $arguments + @("/log", $log_path)
+
+ try {
+ $stdout = &SecEdit.exe $arguments | Out-String
+ } catch {
+ $stderr = $_.Exception.Message
+ }
+ $log = Get-Content -Path $log_path
+ Remove-Item -Path $log_path -Force
+
+ $return = @{
+ log = ($log -join "`n").Trim()
+ stdout = $stdout
+ stderr = $stderr
+ rc = $LASTEXITCODE
+ }
+
+ return $return
+}
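+
+# Illustrative call (hypothetical paths; a sketch, not part of the module flow):
+#   $export = Run-SecEdit -arguments @("/export", "/cfg", "C:\temp\policy.ini", "/quiet")
+#   $export.rc is 0 on success, and $export.log holds the SecEdit.exe log output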
+
+Function Export-SecEdit() {
+ $secedit_ini_path = [IO.Path]::GetTempFileName()
+ # while this will technically make a change to the system in check mode by
+ # creating a new file, we need these values to be able to do anything
+ # substantial in check mode
+ $export_result = Run-SecEdit -arguments @("/export", "/cfg", $secedit_ini_path, "/quiet")
+
+ # check the return code and that the file has been populated, otherwise error out
+ if (($export_result.rc -ne 0) -or ((Get-Item -Path $secedit_ini_path).Length -eq 0)) {
+ Remove-Item -Path $secedit_ini_path -Force
+ $result.rc = $export_result.rc
+ $result.stdout = $export_result.stdout
+ $result.stderr = $export_result.stderr
+ Fail-Json $result "Failed to export secedit.ini file to $($secedit_ini_path)"
+ }
+ $secedit_ini = ConvertFrom-Ini -file_path $secedit_ini_path
+
+ return $secedit_ini
+}
+
+Function Import-SecEdit($ini) {
+ $secedit_ini_path = [IO.Path]::GetTempFileName()
+ $secedit_db_path = [IO.Path]::GetTempFileName()
+ Remove-Item -Path $secedit_db_path -Force # needs to be deleted for SecEdit.exe /import to work
+
+ $ini_contents = ConvertTo-Ini -ini $ini
+ Set-Content -Path $secedit_ini_path -Value $ini_contents
+ $result.changed = $true
+
+ $import_result = Run-SecEdit -arguments @("/configure", "/db", $secedit_db_path, "/cfg", $secedit_ini_path, "/quiet")
+ $result.import_log = $import_result.log
+ Remove-Item -Path $secedit_ini_path -Force
+ if ($import_result.rc -ne 0) {
+ $result.rc = $import_result.rc
+ $result.stdout = $import_result.stdout
+ $result.stderr = $import_result.stderr
+ Fail-Json $result "Failed to import secedit.ini file from $($secedit_ini_path)"
+ }
+}
+
+Function ConvertTo-Ini($ini) {
+ $content = @()
+ foreach ($key in $ini.GetEnumerator()) {
+ $section = $key.Name
+ $values = $key.Value
+
+ $content += "[$section]"
+ foreach ($value in $values.GetEnumerator()) {
+ $value_key = $value.Name
+ $value_value = $value.Value
+
+ if ($null -ne $value_value) {
+ $content += "$value_key = $value_value"
+ }
+ }
+ }
+
+ return $content -join "`r`n"
+}
+
+Function ConvertFrom-Ini($file_path) {
+ $ini = @{}
+ switch -Regex -File $file_path {
+ "^\[(.+)\]" {
+ $section = $matches[1]
+ $ini.$section = @{}
+ }
+ "(.+?)\s*=(.*)" {
+ $name = $matches[1].Trim()
+ $value = $matches[2].Trim()
+ if ($value -match "^\d+$") {
+ $value = [int]$value
+ } elseif ($value.StartsWith('"') -and $value.EndsWith('"')) {
+ $value = $value.Substring(1, $value.Length - 2)
+ }
+
+ $ini.$section.$name = $value
+ }
+ }
+
+ return $ini
+}
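+
+# Illustrative round-trip (assumed sample data): a file containing
+#   [System Access]
+#   NewGuestName = "Guest"
+# parses via ConvertFrom-Ini to @{ 'System Access' = @{ NewGuestName = 'Guest' } }
+# (quotes stripped, purely numeric values cast to [int]), and ConvertTo-Ini
+# renders that table back as INI text, skipping keys whose value is $null.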
+
+if ($section -eq "Privilege Rights") {
+ Add-Warning -obj $result -message "Using this module to edit rights and privileges is error-prone, use the win_user_right module instead"
+}
+
+$will_change = $false
+$secedit_ini = Export-SecEdit
+if (-not ($secedit_ini.ContainsKey($section))) {
+ Fail-Json $result "The section '$section' does not exist in SecEdit.exe output ini"
+}
+
+if ($secedit_ini.$section.ContainsKey($key)) {
+ $current_value = $secedit_ini.$section.$key
+
+ if ($current_value -cne $value) {
+ if ($diff_mode) {
+ $result.diff.prepared = @"
+[$section]
+-$key = $current_value
++$key = $value
+"@
+ }
+
+ $secedit_ini.$section.$key = $value
+ $will_change = $true
+ }
+} elseif ([string]$value -eq "") {
+ # Value was requested to be removed and has already been removed, so do nothing
+} else {
+ if ($diff_mode) {
+ $result.diff.prepared = @"
+[$section]
++$key = $value
+"@
+ }
+ $secedit_ini.$section.$key = $value
+ $will_change = $true
+}
+
+if ($will_change -eq $true) {
+ $result.changed = $true
+ if (-not $check_mode) {
+ Import-SecEdit -ini $secedit_ini
+
+ # secedit doesn't error out on improper entries, re-export and verify
+ # the changes occurred
+ $verification_ini = Export-SecEdit
+ $new_section_values = $verification_ini.$section
+ if ($new_section_values.ContainsKey($key)) {
+ $new_value = $new_section_values.$key
+ if ($new_value -cne $value) {
+ Fail-Json $result "Failed to change the value for key '$key' in section '$section', the value is still $new_value"
+ }
+ } elseif ([string]$value -eq "") {
+ # Value was empty, so OK if no longer in the result
+ } else {
+ Fail-Json $result "The key '$key' in section '$section' is not a valid key, cannot set this value"
+ }
+ }
+}
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_security_policy.py b/test/support/windows-integration/plugins/modules/win_security_policy.py
new file mode 100644
index 00000000..d582a532
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_security_policy.py
@@ -0,0 +1,126 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub, actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_security_policy
+version_added: '2.4'
+short_description: Change local security policy settings
+description:
+- Allows you to set the local security policies that are configured by
+ SecEdit.exe.
+options:
+ section:
+ description:
+ - The ini section the key exists in.
+ - If the section does not exist then the module will return an error.
+ - Example sections to use are 'Account Policies', 'Local Policies',
+ 'Event Log', 'Restricted Groups', 'System Services', 'Registry' and
+ 'File System'.
+ - If wanting to edit the C(Privilege Rights) section, use the
+ M(win_user_right) module instead.
+ type: str
+ required: yes
+ key:
+ description:
+ - The ini key of the section or policy name to modify.
+ - The module will return an error if this key is invalid.
+ type: str
+ required: yes
+ value:
+ description:
+ - The value for the ini key or policy name.
+ - If the key takes a boolean value, then 0 = False and 1 = True.
+ type: str
+ required: yes
+notes:
+- This module uses the SecEdit.exe tool to configure the values; more details
+ of the areas and keys that can be configured can be found at
+ U(https://msdn.microsoft.com/en-us/library/bb742512.aspx).
+- If you are in a domain environment these policies may be set by a GPO; this
+ module can temporarily change these values, but the GPO will override them
+ if the values differ.
+- You can also run C(SecEdit.exe /export /cfg C:\temp\output.ini) to view the
+ current policies set on your system.
+- When assigning user rights, use the M(win_user_right) module instead.
+seealso:
+- module: win_user_right
+author:
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+- name: Change the guest account name
+ win_security_policy:
+ section: System Access
+ key: NewGuestName
+ value: Guest Account
+
+- name: Set the maximum password age
+ win_security_policy:
+ section: System Access
+ key: MaximumPasswordAge
+ value: 15
+
+- name: Do not store passwords using reversible encryption
+ win_security_policy:
+ section: System Access
+ key: ClearTextPassword
+ value: 0
+
+- name: Enable system events
+ win_security_policy:
+ section: Event Audit
+ key: AuditSystemEvents
+ value: 1
+'''
+
+RETURN = r'''
+rc:
+ description: The return code after a failure when running SecEdit.exe.
+ returned: failure with secedit calls
+ type: int
+ sample: -1
+stdout:
+ description: The output of the STDOUT buffer after a failure when running
+ SecEdit.exe.
+ returned: failure with secedit calls
+ type: str
+ sample: check log for error details
+stderr:
+ description: The output of the STDERR buffer after a failure when running
+ SecEdit.exe.
+ returned: failure with secedit calls
+ type: str
+ sample: failed to import security policy
+import_log:
+ description: The log of the SecEdit.exe /configure job that configured the
+ local policies. This is used for debugging purposes on failures.
+ returned: secedit.exe /import run and change occurred
+ type: str
+ sample: Completed 6 percent (0/15) \tProcess Privilege Rights area.
+key:
+ description: The key in the section passed to the module to modify.
+ returned: success
+ type: str
+ sample: NewGuestName
+section:
+ description: The section passed to the module to modify.
+ returned: success
+ type: str
+ sample: System Access
+value:
+ description: The value passed to the module to modify to.
+ returned: success
+ type: str
+ sample: Guest Account
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_shell.ps1 b/test/support/windows-integration/plugins/modules/win_shell.ps1
new file mode 100644
index 00000000..54aef8de
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_shell.ps1
@@ -0,0 +1,138 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.CommandUtil
+#Requires -Module Ansible.ModuleUtils.FileUtil
+
+# TODO: add check mode support
+
+Set-StrictMode -Version 2
+$ErrorActionPreference = "Stop"
+
+# Cleanse CLIXML from stderr (sift out error stream data, discard others for now)
+Function Cleanse-Stderr($raw_stderr) {
+ Try {
+ # NB: this regex isn't perfect, but is decent at finding CLIXML amongst other stderr noise
+ If($raw_stderr -match "(?s)(?<prenoise1>.*)#< CLIXML(?<prenoise2>.*)(?<clixml><Objs.+</Objs>)(?<postnoise>.*)") {
+ $clixml = [xml]$matches["clixml"]
+
+ $merged_stderr = "{0}{1}{2}{3}" -f @(
+ $matches["prenoise1"],
+ $matches["prenoise2"],
+ # filter out just the Error-tagged strings for now, and zap embedded CRLF chars
+ ($clixml.Objs.ChildNodes | Where-Object { $_.Name -eq 'S' } | Where-Object { $_.S -eq 'Error' } | ForEach-Object { $_.'#text'.Replace('_x000D__x000A_','') } | Out-String),
+ $matches["postnoise"]) | Out-String
+
+ return $merged_stderr.Trim()
+
+ # FUTURE: parse/return other streams
+ }
+ Else {
+ $raw_stderr
+ }
+ }
+ Catch {
+ "***EXCEPTION PARSING CLIXML: $_***" + $raw_stderr
+ }
+}
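+
+# Illustrative behaviour (assumed sample stderr, a sketch only): given
+#   '#< CLIXML <Objs Version="1.1.0.1"><S S="Error">oops_x000D__x000A_</S></Objs>'
+# Cleanse-Stderr keeps any text surrounding the envelope, extracts only the
+# Error-tagged strings ('oops'), and strips the encoded CRLF escapes.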
+
+$params = Parse-Args $args -supports_check_mode $false
+
+$raw_command_line = Get-AnsibleParam -obj $params -name "_raw_params" -type "str" -failifempty $true
+$chdir = Get-AnsibleParam -obj $params -name "chdir" -type "path"
+$executable = Get-AnsibleParam -obj $params -name "executable" -type "path"
+$creates = Get-AnsibleParam -obj $params -name "creates" -type "path"
+$removes = Get-AnsibleParam -obj $params -name "removes" -type "path"
+$stdin = Get-AnsibleParam -obj $params -name "stdin" -type "str"
+$no_profile = Get-AnsibleParam -obj $params -name "no_profile" -type "bool" -default $false
+$output_encoding_override = Get-AnsibleParam -obj $params -name "output_encoding_override" -type "str"
+
+$raw_command_line = $raw_command_line.Trim()
+
+$result = @{
+ changed = $true
+ cmd = $raw_command_line
+}
+
+if ($creates -and $(Test-AnsiblePath -Path $creates)) {
+ Exit-Json @{msg="skipped, since $creates exists";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0}
+}
+
+if ($removes -and -not $(Test-AnsiblePath -Path $removes)) {
+ Exit-Json @{msg="skipped, since $removes does not exist";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0}
+}
+
+$exec_args = $null
+If(-not $executable -or $executable -eq "powershell") {
+ $exec_application = "powershell.exe"
+
+ # force input encoding to preamble-free UTF8 so PS sub-processes (eg, Start-Job) don't blow up
+ $raw_command_line = "[Console]::InputEncoding = New-Object Text.UTF8Encoding `$false; " + $raw_command_line
+
+ # Base64 encode the command so we don't have to worry about the various levels of escaping
+ $encoded_command = [Convert]::ToBase64String([System.Text.Encoding]::Unicode.GetBytes($raw_command_line))
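+ # conceptually, the target PowerShell decodes it back unchanged:
+ #   [Text.Encoding]::Unicode.GetString([Convert]::FromBase64String($encoded_command)) -eq $raw_command_line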
+
+ if ($stdin) {
+ $exec_args = "-encodedcommand $encoded_command"
+ } else {
+ $exec_args = "-noninteractive -encodedcommand $encoded_command"
+ }
+
+ if ($no_profile) {
+ $exec_args = "-noprofile $exec_args"
+ }
+}
+Else {
+ # FUTURE: support arg translation from executable (or executable_args?) to process arguments for arbitrary interpreter?
+ $exec_application = $executable
+ if (-not ($exec_application.EndsWith(".exe"))) {
+ $exec_application = "$($exec_application).exe"
+ }
+ $exec_args = "/c $raw_command_line"
+}
+
+$command = "`"$exec_application`" $exec_args"
+$run_command_arg = @{
+ command = $command
+}
+if ($chdir) {
+ $run_command_arg['working_directory'] = $chdir
+}
+if ($stdin) {
+ $run_command_arg['stdin'] = $stdin
+}
+if ($output_encoding_override) {
+ $run_command_arg['output_encoding_override'] = $output_encoding_override
+}
+
+$start_datetime = [DateTime]::UtcNow
+try {
+ $command_result = Run-Command @run_command_arg
+} catch {
+ $result.changed = $false
+ try {
+ $result.rc = $_.Exception.NativeErrorCode
+ } catch {
+ $result.rc = 2
+ }
+ Fail-Json -obj $result -message $_.Exception.Message
+}
+
+# TODO: decode CLIXML stderr output (and other streams?)
+$result.stdout = $command_result.stdout
+$result.stderr = Cleanse-Stderr $command_result.stderr
+$result.rc = $command_result.rc
+
+$end_datetime = [DateTime]::UtcNow
+$result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff")
+$result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff")
+$result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff")
+
+If ($result.rc -ne 0) {
+ Fail-Json -obj $result -message "non-zero return code"
+}
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_shell.py b/test/support/windows-integration/plugins/modules/win_shell.py
new file mode 100644
index 00000000..ee2cd762
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_shell.py
@@ -0,0 +1,167 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2016, Ansible, inc
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_shell
+short_description: Execute shell commands on target hosts
+version_added: 2.2
+description:
+ - The C(win_shell) module takes the command name followed by a list of space-delimited arguments.
+ It is similar to the M(win_command) module, but runs
+ the command via a shell (defaults to PowerShell) on the target host.
+ - For non-Windows targets, use the M(shell) module instead.
+options:
+ free_form:
+ description:
+ - The C(win_shell) module takes a free form command to run.
+ - There is no parameter actually named 'free form'. See the examples!
+ type: str
+ required: yes
+ creates:
+ description:
+ - A path or path filter pattern; when the referenced path exists on the target host, the task will be skipped.
+ type: path
+ removes:
+ description:
+ - A path or path filter pattern; when the referenced path B(does not) exist on the target host, the task will be skipped.
+ type: path
+ chdir:
+ description:
+ - Set the specified path as the current working directory before executing a command.
+ type: path
+ executable:
+ description:
+ - Change the shell used to execute the command (eg, C(cmd)).
+ - The target shell must accept a C(/c) parameter followed by the raw command line to be executed.
+ type: path
+ stdin:
+ description:
+ - Set the stdin of the command directly to the specified value.
+ type: str
+ version_added: '2.5'
+ no_profile:
+ description:
+ - Do not load the user profile before running a command. This is only valid
+ when using PowerShell as the executable.
+ type: bool
+ default: no
+ version_added: '2.8'
+ output_encoding_override:
+ description:
+ - This option overrides the encoding of stdout/stderr output.
+ - You can use this option when you need to run a command which ignores the console's codepage.
+ - You should only need to use this option in very rare circumstances.
+ - This value can be any valid encoding C(Name) based on the output of C([System.Text.Encoding]::GetEncodings()).
+ See U(https://docs.microsoft.com/dotnet/api/system.text.encoding.getencodings).
+ type: str
+ version_added: '2.10'
+notes:
+ - If you want to run an executable securely and predictably, it may be
+ better to use the M(win_command) module instead. Best practices when writing
+ playbooks will follow the trend of using M(win_command) unless C(win_shell) is
+ explicitly required. When running ad-hoc commands, use your best judgement.
+ - WinRM will not return from a command execution until all child processes created have exited.
+ Thus, it is not possible to use C(win_shell) to spawn long-running child or background processes.
+ Consider creating a Windows service for managing background processes.
+seealso:
+- module: psexec
+- module: raw
+- module: script
+- module: shell
+- module: win_command
+- module: win_psexec
+author:
+ - Matt Davis (@nitzmahone)
+'''
+
+EXAMPLES = r'''
+# Execute a command in the remote shell; stdout goes to the specified
+# file on the remote.
+- win_shell: C:\somescript.ps1 >> C:\somelog.txt
+
+# Change the working directory to somedir/ before executing the command.
+- win_shell: C:\somescript.ps1 >> C:\somelog.txt chdir=C:\somedir
+
+# You can also use the 'args' form to provide the options. This command
+# will change the working directory to somedir/ and will only run when
+# somedir/somelog.txt doesn't exist.
+- win_shell: C:\somescript.ps1 >> C:\somelog.txt
+ args:
+ chdir: C:\somedir
+ creates: C:\somelog.txt
+
+# Run a command under a non-Powershell interpreter (cmd in this case)
+- win_shell: echo %HOMEDIR%
+ args:
+ executable: cmd
+ register: homedir_out
+
+- name: Run multi-lined shell commands
+ win_shell: |
+ $value = Test-Path -Path C:\temp
+ if ($value) {
+ Remove-Item -Path C:\temp -Force
+ }
+ New-Item -Path C:\temp -ItemType Directory
+
+- name: Retrieve the input based on stdin
+ win_shell: '$string = [Console]::In.ReadToEnd(); Write-Output $string.Trim()'
+ args:
+ stdin: Input message
+'''
+
+RETURN = r'''
+msg:
+ description: Whether the command changed anything.
+ returned: always
+ type: bool
+ sample: true
+start:
+ description: The command execution start time.
+ returned: always
+ type: str
+ sample: '2016-02-25 09:18:26.429568'
+end:
+ description: The command execution end time.
+ returned: always
+ type: str
+ sample: '2016-02-25 09:18:26.755339'
+delta:
+ description: The command execution delta time.
+ returned: always
+ type: str
+ sample: '0:00:00.325771'
+stdout:
+ description: The command standard output.
+ returned: always
+ type: str
+ sample: 'Clustering node rabbit@slave1 with rabbit@master ...'
+stderr:
+ description: The command standard error.
+ returned: always
+ type: str
+ sample: 'ls: cannot access foo: No such file or directory'
+cmd:
+ description: The command executed by the task.
+ returned: always
+ type: str
+ sample: 'rabbitmqctl join_cluster rabbit@master'
+rc:
+ description: The command return code (0 means success).
+ returned: always
+ type: int
+ sample: 0
+stdout_lines:
+ description: The command standard output split in lines.
+ returned: always
+ type: list
+ sample: [u'Clustering node rabbit@slave1 with rabbit@master ...']
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_stat.ps1 b/test/support/windows-integration/plugins/modules/win_stat.ps1
new file mode 100644
index 00000000..071eb11c
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_stat.ps1
@@ -0,0 +1,186 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+#Requires -Module Ansible.ModuleUtils.FileUtil
+#Requires -Module Ansible.ModuleUtils.LinkUtil
+
+function ConvertTo-Timestamp($start_date, $end_date) {
+ if ($start_date -and $end_date) {
+ return (New-TimeSpan -Start $start_date -End $end_date).TotalSeconds
+ }
+}
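+
+# e.g. ConvertTo-Timestamp -start_date (Get-Date "01/01/1970") -end_date (Get-Date "01/01/1971")
+# returns 31536000 (365 days in seconds); with a missing date it returns nothing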
+
+function Get-FileChecksum($path, $algorithm) {
+ switch ($algorithm) {
+ 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider }
+ 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider }
+ 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider }
+ 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider }
+ 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider }
+ default { Fail-Json -obj $result -message "Unsupported hash algorithm supplied '$algorithm'" }
+ }
+
+ $fp = [System.IO.File]::Open($path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, [System.IO.FileShare]::ReadWrite)
+ try {
+ $hash = [System.BitConverter]::ToString($sp.ComputeHash($fp)).Replace("-", "").ToLower()
+ } finally {
+ $fp.Dispose()
+ }
+
+ return $hash
+}
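+
+# Illustrative usage (hypothetical path): Get-FileChecksum -path 'C:\foo.ini' -algorithm 'sha1'
+# returns the lowercase hex digest of the file contents, e.g. '09cb79e8...'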
+
+function Get-FileInfo {
+ param([String]$Path, [Switch]$Follow)
+
+ $info = Get-AnsibleItem -Path $Path -ErrorAction SilentlyContinue
+ $link_info = $null
+ if ($null -ne $info) {
+ try {
+ $link_info = Get-Link -link_path $info.FullName
+ } catch {
+ $module.Warn("Failed to check/get link info for file: $($_.Exception.Message)")
+ }
+
+ # If follow=true we want to follow the link all the way back to root object
+ if ($Follow -and $null -ne $link_info -and $link_info.Type -in @("SymbolicLink", "JunctionPoint")) {
+ $info, $link_info = Get-FileInfo -Path $link_info.AbsolutePath -Follow
+ }
+ }
+
+ return $info, $link_info
+}
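+
+# e.g. with -Follow and a chain C:\link1 -> C:\link2 -> C:\target (hypothetical paths),
+# this returns the item info for C:\target, with $null link info once no link remains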
+
+$spec = @{
+ options = @{
+ path = @{ type='path'; required=$true; aliases=@( 'dest', 'name' ) }
+ get_checksum = @{ type='bool'; default=$true }
+ checksum_algorithm = @{ type='str'; default='sha1'; choices=@( 'md5', 'sha1', 'sha256', 'sha384', 'sha512' ) }
+ follow = @{ type='bool'; default=$false }
+ }
+ supports_check_mode = $true
+}
+
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+
+$path = $module.Params.path
+$get_checksum = $module.Params.get_checksum
+$checksum_algorithm = $module.Params.checksum_algorithm
+$follow = $module.Params.follow
+
+$module.Result.stat = @{ exists=$false }
+
+Load-LinkUtils
+$info, $link_info = Get-FileInfo -Path $path -Follow:$follow
+If ($null -ne $info) {
+ $epoch_date = Get-Date -Date "01/01/1970"
+ $attributes = @()
+ foreach ($attribute in ($info.Attributes -split ',')) {
+ $attributes += $attribute.Trim()
+ }
+
+ # default values that are always set, specific values are set below this
+ # but are kept commented for easier readability
+ $stat = @{
+ exists = $true
+ attributes = $info.Attributes.ToString()
+ isarchive = ($attributes -contains "Archive")
+ isdir = $false
+ ishidden = ($attributes -contains "Hidden")
+ isjunction = $false
+ islnk = $false
+ isreadonly = ($attributes -contains "ReadOnly")
+ isreg = $false
+ isshared = $false
+ nlink = 1 # Number of links to the file (hard links), overridden below if islnk
+ # lnk_target = islnk or isjunction Target of the symlink. Note that relative paths remain relative
+ # lnk_source = islnk or isjunction Target of the symlink normalized for the remote filesystem
+ hlnk_targets = @()
+ creationtime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.CreationTime)
+ lastaccesstime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastAccessTime)
+ lastwritetime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastWriteTime)
+ # size = a file and directory - calculated below
+ path = $info.FullName
+ filename = $info.Name
+ # extension = a file
+ # owner = set outside this dict in case it fails
+ # sharename = a directory and isshared is True
+ # checksum = a file and get_checksum: True
+ }
+ try {
+ $stat.owner = $info.GetAccessControl().Owner
+ } catch {
+ # may not have rights, historical behaviour was to just set to $null
+ # due to ErrorActionPreference being set to "Continue"
+ $stat.owner = $null
+ }
+
+ # values that are set according to the type of file
+ if ($info.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) {
+ $stat.isdir = $true
+ $share_info = Get-CimInstance -ClassName Win32_Share -Filter "Path='$($stat.path -replace '\\', '\\')'"
+ if ($null -ne $share_info) {
+ $stat.isshared = $true
+ $stat.sharename = $share_info.Name
+ }
+
+ try {
+ $size = 0
+ foreach ($file in $info.EnumerateFiles("*", [System.IO.SearchOption]::AllDirectories)) {
+ $size += $file.Length
+ }
+ $stat.size = $size
+ } catch {
+ $stat.size = 0
+ }
+ } else {
+ $stat.extension = $info.Extension
+ $stat.isreg = $true
+ $stat.size = $info.Length
+
+ if ($get_checksum) {
+ try {
+ $stat.checksum = Get-FileChecksum -path $path -algorithm $checksum_algorithm
+ } catch {
+ $module.FailJson("Failed to get hash of file, set get_checksum to False to ignore this error: $($_.Exception.Message)", $_)
+ }
+ }
+ }
+
+ # Get symbolic link, junction point, hard link info
+ if ($null -ne $link_info) {
+ switch ($link_info.Type) {
+ "SymbolicLink" {
+ $stat.islnk = $true
+ $stat.isreg = $false
+ $stat.lnk_target = $link_info.TargetPath
+ $stat.lnk_source = $link_info.AbsolutePath
+ break
+ }
+ "JunctionPoint" {
+ $stat.isjunction = $true
+ $stat.isreg = $false
+ $stat.lnk_target = $link_info.TargetPath
+ $stat.lnk_source = $link_info.AbsolutePath
+ break
+ }
+ "HardLink" {
+ $stat.lnk_type = "hard"
+ $stat.nlink = $link_info.HardTargets.Count
+
+ # remove current path from the targets
+ $hlnk_targets = $link_info.HardTargets | Where-Object { $_ -ne $stat.path }
+ $stat.hlnk_targets = @($hlnk_targets)
+ break
+ }
+ }
+ }
+
+ $module.Result.stat = $stat
+}
+
+$module.ExitJson()
+
diff --git a/test/support/windows-integration/plugins/modules/win_stat.py b/test/support/windows-integration/plugins/modules/win_stat.py
new file mode 100644
index 00000000..0676b5b2
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_stat.py
@@ -0,0 +1,236 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_stat
+version_added: "1.7"
+short_description: Get information about Windows files
+description:
+ - Returns information about a Windows file.
+ - For non-Windows targets, use the M(stat) module instead.
+options:
+ path:
+ description:
+ - The full path of the file/object to get the facts of; both forward and
+ back slashes are accepted.
+ type: path
+ required: yes
+ aliases: [ dest, name ]
+ get_checksum:
+ description:
+ - Whether to return a checksum of the file (default sha1).
+ type: bool
+ default: yes
+ version_added: "2.1"
+ checksum_algorithm:
+ description:
+ - Algorithm to determine checksum of file.
+ - Will throw an error if the host is unable to use the specified algorithm.
+ type: str
+ default: sha1
+ choices: [ md5, sha1, sha256, sha384, sha512 ]
+ version_added: "2.3"
+ follow:
+ description:
+ - Whether to follow symlinks or junction points.
+ - If C(path) points to another link, that link will
+ be followed until no more links are found.
+ type: bool
+ default: no
+ version_added: "2.8"
+seealso:
+- module: stat
+- module: win_acl
+- module: win_file
+- module: win_owner
+author:
+- Chris Church (@cchurch)
+'''
+
+EXAMPLES = r'''
+- name: Obtain information about a file
+ win_stat:
+ path: C:\foo.ini
+ register: file_info
+
+- name: Obtain information about a folder
+ win_stat:
+ path: C:\bar
+ register: folder_info
+
+- name: Get MD5 checksum of a file
+ win_stat:
+ path: C:\foo.ini
+ get_checksum: yes
+ checksum_algorithm: md5
+ register: md5_checksum
+
+- debug:
+ var: md5_checksum.stat.checksum
+
+- name: Get SHA1 checksum of file
+ win_stat:
+ path: C:\foo.ini
+ get_checksum: yes
+ register: sha1_checksum
+
+- debug:
+ var: sha1_checksum.stat.checksum
+
+- name: Get SHA256 checksum of file
+ win_stat:
+ path: C:\foo.ini
+ get_checksum: yes
+ checksum_algorithm: sha256
+ register: sha256_checksum
+
+- debug:
+ var: sha256_checksum.stat.checksum
+'''
+
+RETURN = r'''
+changed:
+ description: Whether anything was changed.
+ returned: always
+ type: bool
+ sample: true
+stat:
+ description: Dictionary containing all the stat data.
+ returned: success
+ type: complex
+ contains:
+ attributes:
+ description: Attributes of the file at path in raw form.
+ returned: success, path exists
+ type: str
+ sample: "Archive, Hidden"
+ checksum:
+ description: The checksum of a file based on checksum_algorithm specified.
+ returned: success, path exists, path is a file, get_checksum == True and the
+ specified checksum_algorithm is supported
+ type: str
+ sample: 09cb79e8fc7453c84a07f644e441fd81623b7f98
+ creationtime:
+ description: The create time of the file represented in seconds since epoch.
+ returned: success, path exists
+ type: float
+ sample: 1477984205.15
+ exists:
+ description: If the path exists or not.
+ returned: success
+ type: bool
+ sample: true
+ extension:
+ description: The extension of the file at path.
+ returned: success, path exists, path is a file
+ type: str
+ sample: ".ps1"
+ filename:
+ description: The name of the file (without path).
+ returned: success, path exists, path is a file
+ type: str
+ sample: foo.ini
+ hlnk_targets:
+ description: List of other files pointing to the same file (hard links), excludes the current file.
+ returned: success, path exists
+ type: list
+ sample:
+ - C:\temp\file.txt
+ - C:\Windows\update.log
+ isarchive:
+ description: If the path is ready for archiving or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isdir:
+ description: If the path is a directory or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ ishidden:
+ description: If the path is hidden or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isjunction:
+ description: If the path is a junction point or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ islnk:
+ description: If the path is a symbolic link or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isreadonly:
+ description: If the path is read only or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isreg:
+ description: If the path is a regular file.
+ returned: success, path exists
+ type: bool
+ sample: true
+ isshared:
+ description: If the path is shared or not.
+ returned: success, path exists
+ type: bool
+ sample: true
+ lastaccesstime:
+ description: The last access time of the file represented in seconds since epoch.
+ returned: success, path exists
+ type: float
+ sample: 1477984205.15
+ lastwritetime:
+ description: The last modification time of the file represented in seconds since epoch.
+ returned: success, path exists
+ type: float
+ sample: 1477984205.15
+ lnk_source:
+ description: Target of the symlink normalized for the remote filesystem.
+ returned: success, path exists and the path is a symbolic link or junction point
+ type: str
+ sample: C:\temp\link
+ lnk_target:
+ description: Target of the symlink. Note that relative paths remain relative.
+ returned: success, path exists and the path is a symbolic link or junction point
+ type: str
+ sample: ..\link
+ nlink:
+ description: Number of links to the file (hard links).
+ returned: success, path exists
+ type: int
+ sample: 1
+ owner:
+ description: The owner of the file.
+ returned: success, path exists
+ type: str
+ sample: BUILTIN\Administrators
+ path:
+ description: The full absolute path to the file.
+ returned: success, path exists, file exists
+ type: str
+ sample: C:\foo.ini
+ sharename:
+ description: The name of share if folder is shared.
+ returned: success, path exists, file is a directory and isshared == True
+ type: str
+ sample: file-share
+ size:
+ description: The size in bytes of a file or folder.
+ returned: success, path exists, file is not a link
+ type: int
+ sample: 1024
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_tempfile.ps1 b/test/support/windows-integration/plugins/modules/win_tempfile.ps1
new file mode 100644
index 00000000..9a1a7174
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_tempfile.ps1
@@ -0,0 +1,72 @@
+#!powershell
+
+# Copyright: (c) 2017, Dag Wieers (@dagwieers) <dag@wieers.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.Basic
+
+Function New-TempFile {
+ Param ([string]$path, [string]$prefix, [string]$suffix, [string]$type, [bool]$checkmode)
+ $temppath = $null
+ $curerror = $null
+ $attempt = 0
+
+ # Since we don't know if the file already exists, we try 5 times with a random name
+ do {
+ $attempt += 1
+ $randomname = [System.IO.Path]::GetRandomFileName()
+ $temppath = (Join-Path -Path $path -ChildPath "$prefix$randomname$suffix")
+ Try {
+ $file = New-Item -Path $temppath -ItemType $type -WhatIf:$checkmode
+ # Makes sure we get the full absolute path of the created temp file and not a relative or DOS 8.3 dir
+ if (-not $checkmode) {
+ $temppath = $file.FullName
+ } else {
+ # Just rely on GetFullPath for check mode
+ $temppath = [System.IO.Path]::GetFullPath($temppath)
+ }
+ } Catch {
+ $temppath = $null
+ $curerror = $_
+ }
+ } until (($null -ne $temppath) -or ($attempt -ge 5))
+
+ # If it fails 5 times, something is wrong and we have to report the details
+ if ($null -eq $temppath) {
+ $module.FailJson("No random temporary file worked in $attempt attempts. Error: $($curerror.Exception.Message)", $curerror)
+ }
+
+ return $temppath.ToString()
+}
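+
+# Illustrative call (assumed values): New-TempFile -path $env:TEMP -prefix 'ansible.' -suffix '' -type 'file' -checkmode $false
+# creates a file with a random 8.3 name, e.g. C:\Users\user\AppData\Local\Temp\ansible.lwrpkzin.ltj,
+# and returns that full path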
+
+$spec = @{
+ options = @{
+ path = @{ type='path'; default='%TEMP%'; aliases=@( 'dest' ) }
+ state = @{ type='str'; default='file'; choices=@( 'directory', 'file') }
+ prefix = @{ type='str'; default='ansible.' }
+ suffix = @{ type='str' }
+ }
+ supports_check_mode = $true
+}
+
+$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec)
+
+$path = $module.Params.path
+$state = $module.Params.state
+$prefix = $module.Params.prefix
+$suffix = $module.Params.suffix
+
+# Expand environment variables on non-path types
+if ($null -ne $prefix) {
+ $prefix = [System.Environment]::ExpandEnvironmentVariables($prefix)
+}
+if ($null -ne $suffix) {
+ $suffix = [System.Environment]::ExpandEnvironmentVariables($suffix)
+}
+
+$module.Result.changed = $true
+$module.Result.state = $state
+
+$module.Result.path = New-TempFile -Path $path -Prefix $prefix -Suffix $suffix -Type $state -CheckMode $module.CheckMode
+
+$module.ExitJson()
diff --git a/test/support/windows-integration/plugins/modules/win_tempfile.py b/test/support/windows-integration/plugins/modules/win_tempfile.py
new file mode 100644
index 00000000..58dd6501
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_tempfile.py
@@ -0,0 +1,67 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017, Dag Wieers <dag@wieers.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_tempfile
+version_added: "2.3"
+short_description: Creates temporary files and directories
+description:
+ - Creates temporary files and directories.
+ - For non-Windows targets, please use the M(tempfile) module instead.
+options:
+ state:
+ description:
+ - Whether to create file or directory.
+ type: str
+ choices: [ directory, file ]
+ default: file
+ path:
+ description:
+ - Location where temporary file or directory should be created.
+ - If path is not specified, the default system temporary directory (%TEMP%) will be used.
+ type: path
+ default: '%TEMP%'
+ aliases: [ dest ]
+ prefix:
+ description:
+ - Prefix of file/directory name created by module.
+ type: str
+ default: ansible.
+ suffix:
+ description:
+ - Suffix of file/directory name created by module.
+ type: str
+ default: ''
+seealso:
+- module: tempfile
+author:
+- Dag Wieers (@dagwieers)
+'''
+
+EXAMPLES = r"""
+- name: Create temporary build directory
+ win_tempfile:
+ state: directory
+ suffix: build
+
+- name: Create temporary file
+ win_tempfile:
+ state: file
+ suffix: temp
+"""
+
+RETURN = r'''
+path:
+ description: The absolute path to the created file or directory.
+ returned: success
+ type: str
+ sample: C:\Users\Administrator\AppData\Local\Temp\ansible.bMlvdk
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_template.py b/test/support/windows-integration/plugins/modules/win_template.py
new file mode 100644
index 00000000..bd8b2492
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_template.py
@@ -0,0 +1,66 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a virtual module that is entirely implemented server side
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_template
+version_added: "1.9.2"
+short_description: Template a file out to a remote server
+options:
+ backup:
+ description:
+ - Determine whether a backup should be created.
+ - When set to C(yes), create a backup file including the timestamp information
+ so you can get the original file back if you somehow clobbered it incorrectly.
+ type: bool
+ default: no
+ version_added: '2.8'
+ newline_sequence:
+ default: '\r\n'
+ force:
+ version_added: '2.4'
+notes:
+- Beware fetching files from Windows machines when creating templates, because certain tools, such as PowerShell ISE
+ and regedit's export facility, add a Byte Order Mark as the first character of the file, which can cause tracebacks.
+- You can use the M(win_copy) module with the C(content:) option if you prefer the template inline, as part of the
+ playbook.
+- For Linux you can use M(template) which uses '\\n' as C(newline_sequence) by default.
+seealso:
+- module: win_copy
+- module: copy
+- module: template
+author:
+- Jon Hawkesworth (@jhawkesworth)
+extends_documentation_fragment:
+- template_common
+'''
+
+EXAMPLES = r'''
+- name: Create a file from a Jinja2 template
+ win_template:
+ src: /mytemplates/file.conf.j2
+ dest: C:\Temp\file.conf
+
+- name: Create a Unix-style file from a Jinja2 template
+ win_template:
+ src: unix/config.conf.j2
+ dest: C:\share\unix\config.conf
+ newline_sequence: '\n'
+ backup: yes
+'''
+
+RETURN = r'''
+backup_file:
+ description: Name of the backup file that was created.
+ returned: if backup=yes
+ type: str
+ sample: C:\Path\To\File.txt.11540.20150212-220915.bak
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_user.ps1 b/test/support/windows-integration/plugins/modules/win_user.ps1
new file mode 100644
index 00000000..54905cb2
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_user.ps1
@@ -0,0 +1,273 @@
+#!powershell
+
+# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#AnsibleRequires -CSharpUtil Ansible.AccessToken
+#Requires -Module Ansible.ModuleUtils.Legacy
+
+########
+$ADS_UF_PASSWD_CANT_CHANGE = 64
+$ADS_UF_DONT_EXPIRE_PASSWD = 65536
+
+$adsi = [ADSI]"WinNT://$env:COMPUTERNAME"
+
+function Get-User($user) {
+ $adsi.Children | Where-Object {$_.SchemaClassName -eq 'user' -and $_.Name -eq $user }
+ return
+}
+
+function Get-UserFlag($user, $flag) {
+ If ($user.UserFlags[0] -band $flag) {
+ $true
+ }
+ Else {
+ $false
+ }
+}
+
+function Set-UserFlag($user, $flag) {
+ $user.UserFlags = ($user.UserFlags[0] -BOR $flag)
+}
+
+function Clear-UserFlag($user, $flag) {
+ $user.UserFlags = ($user.UserFlags[0] -BXOR $flag)
+}
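+
+# These operate on the ADS_UF_* bitmask, e.g. with UserFlags[0] = 0x10040 (assumed value):
+#   Get-UserFlag $user 64   -> $true (the ADS_UF_PASSWD_CANT_CHANGE bit is set)
+#   Clear-UserFlag $user 64 # -BXOR toggles the bit; callers only invoke it when the flag is known to be set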
+
+function Get-Group($grp) {
+ $adsi.Children | Where-Object { $_.SchemaClassName -eq 'Group' -and $_.Name -eq $grp }
+ return
+}
+
+Function Test-LocalCredential {
+ param([String]$Username, [String]$Password)
+
+ try {
+ $handle = [Ansible.AccessToken.TokenUtil]::LogonUser($Username, $null, $Password, "Network", "Default")
+ $handle.Dispose()
+ $valid_credentials = $true
+ } catch [Ansible.AccessToken.Win32Exception] {
+ # following errors indicate the creds are correct but the user was
+ # unable to log on for other reasons, which we don't care about
+ $success_codes = @(
+ 0x0000052F, # ERROR_ACCOUNT_RESTRICTION
+ 0x00000530, # ERROR_INVALID_LOGON_HOURS
+ 0x00000531, # ERROR_INVALID_WORKSTATION
+ 0x00000569 # ERROR_LOGON_TYPE_GRANTED
+ )
+
+ if ($_.Exception.NativeErrorCode -eq 0x0000052E) {
+ # ERROR_LOGON_FAILURE - the user or pass was incorrect
+ $valid_credentials = $false
+ } elseif ($_.Exception.NativeErrorCode -in $success_codes) {
+ $valid_credentials = $true
+ } else {
+ # an unknown failure, reraise exception
+ throw $_
+ }
+ }
+ return $valid_credentials
+}
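+
+# Illustrative call (hypothetical account): Test-LocalCredential -Username 'bob' -Password 'B0bP4ssw0rd'
+# returns $true for a correct password even when the account is merely restricted
+# from this logon type, and $false only on ERROR_LOGON_FAILURE (0x52E)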
+
+########
+
+$params = Parse-Args $args;
+
+$result = @{
+ changed = $false
+};
+
+$username = Get-AnsibleParam -obj $params -name "name" -type "str" -failifempty $true
+$fullname = Get-AnsibleParam -obj $params -name "fullname" -type "str"
+$description = Get-AnsibleParam -obj $params -name "description" -type "str"
+$password = Get-AnsibleParam -obj $params -name "password" -type "str"
+$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent","query"
+$update_password = Get-AnsibleParam -obj $params -name "update_password" -type "str" -default "always" -validateset "always","on_create"
+$password_expired = Get-AnsibleParam -obj $params -name "password_expired" -type "bool"
+$password_never_expires = Get-AnsibleParam -obj $params -name "password_never_expires" -type "bool"
+$user_cannot_change_password = Get-AnsibleParam -obj $params -name "user_cannot_change_password" -type "bool"
+$account_disabled = Get-AnsibleParam -obj $params -name "account_disabled" -type "bool"
+$account_locked = Get-AnsibleParam -obj $params -name "account_locked" -type "bool"
+$groups = Get-AnsibleParam -obj $params -name "groups"
+$groups_action = Get-AnsibleParam -obj $params -name "groups_action" -type "str" -default "replace" -validateset "add","remove","replace"
+
+If ($null -ne $account_locked -and $account_locked) {
+ Fail-Json $result "account_locked must be set to 'no' if provided"
+}
+
+If ($null -ne $groups) {
+ If ($groups -is [System.String]) {
+ [string[]]$groups = $groups.Split(",")
+ }
+ ElseIf ($groups -isnot [System.Collections.IList]) {
+ Fail-Json $result "groups must be a string or array"
+ }
+ $groups = $groups | ForEach-Object { ([string]$_).Trim() } | Where-Object { $_ }
+ If ($null -eq $groups) {
+ $groups = @()
+ }
+}
+
+$user_obj = Get-User $username
+
+If ($state -eq 'present') {
+ # Add or update user
+ try {
+ If (-not $user_obj) {
+ $user_obj = $adsi.Create("User", $username)
+ If ($null -ne $password) {
+ $user_obj.SetPassword($password)
+ }
+ $user_obj.SetInfo()
+ $result.changed = $true
+ }
+ ElseIf (($null -ne $password) -and ($update_password -eq 'always')) {
+ # ValidateCredentials will fail if either of these is true - just force an update...
+ If($user_obj.AccountDisabled -or $user_obj.PasswordExpired) {
+ $password_match = $false
+ }
+ Else {
+ try {
+ $password_match = Test-LocalCredential -Username $username -Password $password
+ } catch [System.ComponentModel.Win32Exception] {
+ Fail-Json -obj $result -message "Failed to validate the user's credentials: $($_.Exception.Message)"
+ }
+ }
+
+ If (-not $password_match) {
+ $user_obj.SetPassword($password)
+ $result.changed = $true
+ }
+ }
+ If (($null -ne $fullname) -and ($fullname -ne $user_obj.FullName[0])) {
+ $user_obj.FullName = $fullname
+ $result.changed = $true
+ }
+ If (($null -ne $description) -and ($description -ne $user_obj.Description[0])) {
+ $user_obj.Description = $description
+ $result.changed = $true
+ }
+ If (($null -ne $password_expired) -and ($password_expired -ne ($user_obj.PasswordExpired | ConvertTo-Bool))) {
+ $user_obj.PasswordExpired = If ($password_expired) { 1 } Else { 0 }
+ $result.changed = $true
+ }
+ If (($null -ne $password_never_expires) -and ($password_never_expires -ne (Get-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD))) {
+ If ($password_never_expires) {
+ Set-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD
+ }
+ Else {
+ Clear-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD
+ }
+ $result.changed = $true
+ }
+ If (($null -ne $user_cannot_change_password) -and ($user_cannot_change_password -ne (Get-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE))) {
+ If ($user_cannot_change_password) {
+ Set-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE
+ }
+ Else {
+ Clear-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE
+ }
+ $result.changed = $true
+ }
+ If (($null -ne $account_disabled) -and ($account_disabled -ne $user_obj.AccountDisabled)) {
+ $user_obj.AccountDisabled = $account_disabled
+ $result.changed = $true
+ }
+ If (($null -ne $account_locked) -and ($account_locked -ne $user_obj.IsAccountLocked)) {
+ $user_obj.IsAccountLocked = $account_locked
+ $result.changed = $true
+ }
+ If ($result.changed) {
+ $user_obj.SetInfo()
+ }
+ If ($null -ne $groups) {
+ [string[]]$current_groups = $user_obj.Groups() | ForEach-Object { $_.GetType().InvokeMember("Name", "GetProperty", $null, $_, $null) }
+ If (($groups_action -eq "remove") -or ($groups_action -eq "replace")) {
+ ForEach ($grp in $current_groups) {
+ If ((($groups_action -eq "remove") -and ($groups -contains $grp)) -or (($groups_action -eq "replace") -and ($groups -notcontains $grp))) {
+ $group_obj = Get-Group $grp
+ If ($group_obj) {
+ $group_obj.Remove($user_obj.Path)
+ $result.changed = $true
+ }
+ Else {
+ Fail-Json $result "group '$grp' not found"
+ }
+ }
+ }
+ }
+ If (($groups_action -eq "add") -or ($groups_action -eq "replace")) {
+ ForEach ($grp in $groups) {
+ If ($current_groups -notcontains $grp) {
+ $group_obj = Get-Group $grp
+ If ($group_obj) {
+ $group_obj.Add($user_obj.Path)
+ $result.changed = $true
+ }
+ Else {
+ Fail-Json $result "group '$grp' not found"
+ }
+ }
+ }
+ }
+ }
+ }
+ catch {
+ Fail-Json $result $_.Exception.Message
+ }
+}
+ElseIf ($state -eq 'absent') {
+ # Remove user
+ try {
+ If ($user_obj) {
+ $username = $user_obj.Name.Value
+ $adsi.delete("User", $user_obj.Name.Value)
+ $result.changed = $true
+ $result.msg = "User '$username' deleted successfully"
+ $user_obj = $null
+ } else {
+ $result.msg = "User '$username' was not found"
+ }
+ }
+ catch {
+ Fail-Json $result $_.Exception.Message
+ }
+}
+
+try {
+ If ($user_obj -and $user_obj -is [System.DirectoryServices.DirectoryEntry]) {
+ $user_obj.RefreshCache()
+ $result.name = $user_obj.Name[0]
+ $result.fullname = $user_obj.FullName[0]
+ $result.path = $user_obj.Path
+ $result.description = $user_obj.Description[0]
+ $result.password_expired = ($user_obj.PasswordExpired | ConvertTo-Bool)
+ $result.password_never_expires = (Get-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD)
+ $result.user_cannot_change_password = (Get-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE)
+ $result.account_disabled = $user_obj.AccountDisabled
+ $result.account_locked = $user_obj.IsAccountLocked
+ $result.sid = (New-Object System.Security.Principal.SecurityIdentifier($user_obj.ObjectSid.Value, 0)).Value
+ $user_groups = @()
+ ForEach ($grp in $user_obj.Groups()) {
+ $group_result = @{
+ name = $grp.GetType().InvokeMember("Name", "GetProperty", $null, $grp, $null)
+ path = $grp.GetType().InvokeMember("ADsPath", "GetProperty", $null, $grp, $null)
+ }
+ $user_groups += $group_result;
+ }
+ $result.groups = $user_groups
+ $result.state = "present"
+ }
+ Else {
+ $result.name = $username
+ if ($state -eq 'query') {
+ $result.msg = "User '$username' was not found"
+ }
+ $result.state = "absent"
+ }
+}
+catch {
+ Fail-Json $result $_.Exception.Message
+}
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_user.py b/test/support/windows-integration/plugins/modules/win_user.py
new file mode 100644
index 00000000..5fc0633d
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_user.py
@@ -0,0 +1,194 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2014, Matt Martz <matt@sivel.net>, and others
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['stableinterface'],
+ 'supported_by': 'core'}
+
+DOCUMENTATION = r'''
+---
+module: win_user
+version_added: "1.7"
+short_description: Manages local Windows user accounts
+description:
+ - Manages local Windows user accounts.
+ - For non-Windows targets, use the M(user) module instead.
+options:
+ name:
+ description:
+ - Name of the user to create, remove or modify.
+ type: str
+ required: yes
+ fullname:
+ description:
+ - Full name of the user.
+ type: str
+ version_added: "1.9"
+ description:
+ description:
+ - Description of the user.
+ type: str
+ version_added: "1.9"
+ password:
+ description:
+ - Optionally set the user's password to this (plain text) value.
+ type: str
+ update_password:
+ description:
+ - C(always) will update passwords if they differ. C(on_create) will
+ only set the password for newly created users.
+ type: str
+ choices: [ always, on_create ]
+ default: always
+ version_added: "1.9"
+ password_expired:
+ description:
+ - C(yes) will require the user to change their password at next login.
+ - C(no) will clear the expired password flag.
+ type: bool
+ version_added: "1.9"
+ password_never_expires:
+ description:
+ - C(yes) will set the password to never expire.
+ - C(no) will allow the password to expire.
+ type: bool
+ version_added: "1.9"
+ user_cannot_change_password:
+ description:
+ - C(yes) will prevent the user from changing their password.
+ - C(no) will allow the user to change their password.
+ type: bool
+ version_added: "1.9"
+ account_disabled:
+ description:
+ - C(yes) will disable the user account.
+ - C(no) will clear the disabled flag.
+ type: bool
+ version_added: "1.9"
+ account_locked:
+ description:
+ - C(no) will unlock the user account if locked.
+ choices: [ 'no' ]
+ version_added: "1.9"
+ groups:
+ description:
+ - Adds or removes the user from this comma-separated list of groups,
+ depending on the value of I(groups_action).
+ - When I(groups_action) is C(replace) and I(groups) is set to the empty
+ string ('groups='), the user is removed from all groups.
+ version_added: "1.9"
+ groups_action:
+ description:
+ - If C(add), the user is added to each group in I(groups) where not
+ already a member.
+ - If C(replace), the user is added as a member of each group in
+ I(groups) and removed from any other groups.
+ - If C(remove), the user is removed from each group in I(groups).
+ type: str
+ choices: [ add, replace, remove ]
+ default: replace
+ version_added: "1.9"
+ state:
+ description:
+ - When C(absent), removes the user account if it exists.
+ - When C(present), creates or updates the user account.
+ - When C(query) (new in 1.9), retrieves the user account details
+ without making any changes.
+ type: str
+ choices: [ absent, present, query ]
+ default: present
+seealso:
+- module: user
+- module: win_domain_membership
+- module: win_domain_user
+- module: win_group
+- module: win_group_membership
+- module: win_user_profile
+author:
+ - Paul Durivage (@angstwad)
+ - Chris Church (@cchurch)
+'''
+
+EXAMPLES = r'''
+- name: Ensure user bob is present
+ win_user:
+ name: bob
+ password: B0bP4ssw0rd
+ state: present
+ groups:
+ - Users
+
+- name: Ensure user bob is absent
+ win_user:
+ name: bob
+ state: absent
+'''
+
+RETURN = r'''
+account_disabled:
+ description: Whether the user is disabled.
+ returned: user exists
+ type: bool
+ sample: false
+account_locked:
+ description: Whether the user is locked.
+ returned: user exists
+ type: bool
+ sample: false
+description:
+ description: The description set for the user.
+ returned: user exists
+ type: str
+ sample: Username for test
+fullname:
+ description: The full name set for the user.
+ returned: user exists
+ type: str
+ sample: Test Username
+groups:
+ description: A list of groups and their ADSI path the user is a member of.
+ returned: user exists
+ type: list
+ sample: [
+ {
+ "name": "Administrators",
+ "path": "WinNT://WORKGROUP/USER-PC/Administrators"
+ }
+ ]
+name:
+  description: The name of the user.
+ returned: always
+ type: str
+ sample: username
+password_expired:
+ description: Whether the password is expired.
+ returned: user exists
+ type: bool
+ sample: false
+password_never_expires:
+ description: Whether the password is set to never expire.
+ returned: user exists
+ type: bool
+ sample: true
+path:
+ description: The ADSI path for the user.
+ returned: user exists
+ type: str
+ sample: "WinNT://WORKGROUP/USER-PC/username"
+sid:
+ description: The SID for the user.
+ returned: user exists
+ type: str
+ sample: S-1-5-21-3322259488-2828151810-3939402796-1001
+user_cannot_change_password:
+ description: Whether the user can change their own password.
+ returned: user exists
+ type: bool
+ sample: false
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_user_right.ps1 b/test/support/windows-integration/plugins/modules/win_user_right.ps1
new file mode 100644
index 00000000..3fac52a8
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_user_right.ps1
@@ -0,0 +1,349 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.SID
+
+$ErrorActionPreference = 'Stop'
+
+$params = Parse-Args $args -supports_check_mode $true
+$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false
+$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false
+$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP
+
+$name = Get-AnsibleParam -obj $params -name "name" -type "str" -failifempty $true
+$users = Get-AnsibleParam -obj $params -name "users" -type "list" -failifempty $true
+$action = Get-AnsibleParam -obj $params -name "action" -type "str" -default "set" -validateset "add","remove","set"
+
+$result = @{
+ changed = $false
+ added = @()
+ removed = @()
+}
+
+if ($diff_mode) {
+ $result.diff = @{}
+}
+
+$sec_helper_util = @"
+using System;
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+using System.Security.Principal;
+
+namespace Ansible
+{
+ public class LsaRightHelper : IDisposable
+ {
+ // Code modified from https://gallery.technet.microsoft.com/scriptcenter/Grant-Revoke-Query-user-26e259b0
+
+ enum Access : int
+ {
+ POLICY_READ = 0x20006,
+ POLICY_ALL_ACCESS = 0x00F0FFF,
+ POLICY_EXECUTE = 0X20801,
+ POLICY_WRITE = 0X207F8
+ }
+
+ IntPtr lsaHandle;
+
+ const string LSA_DLL = "advapi32.dll";
+ const CharSet DEFAULT_CHAR_SET = CharSet.Unicode;
+
+ const uint STATUS_NO_MORE_ENTRIES = 0x8000001a;
+ const uint STATUS_NO_SUCH_PRIVILEGE = 0xc0000060;
+
+ internal sealed class Sid : IDisposable
+ {
+ public IntPtr pSid = IntPtr.Zero;
+ public SecurityIdentifier sid = null;
+
+ public Sid(string sidString)
+ {
+ try
+ {
+ sid = new SecurityIdentifier(sidString);
+ } catch
+ {
+ throw new ArgumentException(String.Format("SID string {0} could not be converted to SecurityIdentifier", sidString));
+ }
+
+ Byte[] buffer = new Byte[sid.BinaryLength];
+ sid.GetBinaryForm(buffer, 0);
+
+ pSid = Marshal.AllocHGlobal(sid.BinaryLength);
+ Marshal.Copy(buffer, 0, pSid, sid.BinaryLength);
+ }
+
+ public void Dispose()
+ {
+ if (pSid != IntPtr.Zero)
+ {
+ Marshal.FreeHGlobal(pSid);
+ pSid = IntPtr.Zero;
+ }
+ GC.SuppressFinalize(this);
+ }
+ ~Sid() { Dispose(); }
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ private struct LSA_OBJECT_ATTRIBUTES
+ {
+ public int Length;
+ public IntPtr RootDirectory;
+ public IntPtr ObjectName;
+ public int Attributes;
+ public IntPtr SecurityDescriptor;
+ public IntPtr SecurityQualityOfService;
+ }
+
+ [StructLayout(LayoutKind.Sequential, CharSet = DEFAULT_CHAR_SET)]
+ private struct LSA_UNICODE_STRING
+ {
+ public ushort Length;
+ public ushort MaximumLength;
+ [MarshalAs(UnmanagedType.LPWStr)]
+ public string Buffer;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ private struct LSA_ENUMERATION_INFORMATION
+ {
+ public IntPtr Sid;
+ }
+
+ [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)]
+ private static extern uint LsaOpenPolicy(
+ LSA_UNICODE_STRING[] SystemName,
+ ref LSA_OBJECT_ATTRIBUTES ObjectAttributes,
+ int AccessMask,
+ out IntPtr PolicyHandle
+ );
+
+ [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)]
+ private static extern uint LsaAddAccountRights(
+ IntPtr PolicyHandle,
+ IntPtr pSID,
+ LSA_UNICODE_STRING[] UserRights,
+ int CountOfRights
+ );
+
+ [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)]
+ private static extern uint LsaRemoveAccountRights(
+ IntPtr PolicyHandle,
+ IntPtr pSID,
+ bool AllRights,
+ LSA_UNICODE_STRING[] UserRights,
+ int CountOfRights
+ );
+
+ [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)]
+ private static extern uint LsaEnumerateAccountsWithUserRight(
+ IntPtr PolicyHandle,
+ LSA_UNICODE_STRING[] UserRights,
+ out IntPtr EnumerationBuffer,
+ out ulong CountReturned
+ );
+
+ [DllImport(LSA_DLL)]
+ private static extern int LsaNtStatusToWinError(int NTSTATUS);
+
+ [DllImport(LSA_DLL)]
+ private static extern int LsaClose(IntPtr PolicyHandle);
+
+ [DllImport(LSA_DLL)]
+ private static extern int LsaFreeMemory(IntPtr Buffer);
+
+ public LsaRightHelper()
+ {
+ LSA_OBJECT_ATTRIBUTES lsaAttr;
+ lsaAttr.RootDirectory = IntPtr.Zero;
+ lsaAttr.ObjectName = IntPtr.Zero;
+ lsaAttr.Attributes = 0;
+ lsaAttr.SecurityDescriptor = IntPtr.Zero;
+ lsaAttr.SecurityQualityOfService = IntPtr.Zero;
+ lsaAttr.Length = Marshal.SizeOf(typeof(LSA_OBJECT_ATTRIBUTES));
+
+ lsaHandle = IntPtr.Zero;
+
+ LSA_UNICODE_STRING[] system = new LSA_UNICODE_STRING[1];
+ system[0] = InitLsaString("");
+
+ uint ret = LsaOpenPolicy(system, ref lsaAttr, (int)Access.POLICY_ALL_ACCESS, out lsaHandle);
+ if (ret != 0)
+ throw new Win32Exception(LsaNtStatusToWinError((int)ret));
+ }
+
+ public void AddPrivilege(string sidString, string privilege)
+ {
+ uint ret = 0;
+ using (Sid sid = new Sid(sidString))
+ {
+ LSA_UNICODE_STRING[] privileges = new LSA_UNICODE_STRING[1];
+ privileges[0] = InitLsaString(privilege);
+ ret = LsaAddAccountRights(lsaHandle, sid.pSid, privileges, 1);
+ }
+ if (ret != 0)
+ throw new Win32Exception(LsaNtStatusToWinError((int)ret));
+ }
+
+ public void RemovePrivilege(string sidString, string privilege)
+ {
+ uint ret = 0;
+ using (Sid sid = new Sid(sidString))
+ {
+ LSA_UNICODE_STRING[] privileges = new LSA_UNICODE_STRING[1];
+ privileges[0] = InitLsaString(privilege);
+ ret = LsaRemoveAccountRights(lsaHandle, sid.pSid, false, privileges, 1);
+ }
+ if (ret != 0)
+ throw new Win32Exception(LsaNtStatusToWinError((int)ret));
+ }
+
+ public string[] EnumerateAccountsWithUserRight(string privilege)
+ {
+ uint ret = 0;
+ ulong count = 0;
+ LSA_UNICODE_STRING[] rights = new LSA_UNICODE_STRING[1];
+ rights[0] = InitLsaString(privilege);
+ IntPtr buffer = IntPtr.Zero;
+
+ ret = LsaEnumerateAccountsWithUserRight(lsaHandle, rights, out buffer, out count);
+ switch (ret)
+ {
+ case 0:
+ string[] accounts = new string[count];
+ for (int i = 0; i < (int)count; i++)
+ {
+ LSA_ENUMERATION_INFORMATION LsaInfo = (LSA_ENUMERATION_INFORMATION)Marshal.PtrToStructure(
+ IntPtr.Add(buffer, i * Marshal.SizeOf(typeof(LSA_ENUMERATION_INFORMATION))),
+ typeof(LSA_ENUMERATION_INFORMATION));
+
+ accounts[i] = new SecurityIdentifier(LsaInfo.Sid).ToString();
+ }
+ LsaFreeMemory(buffer);
+ return accounts;
+
+ case STATUS_NO_MORE_ENTRIES:
+ return new string[0];
+
+ case STATUS_NO_SUCH_PRIVILEGE:
+ throw new ArgumentException(String.Format("Invalid privilege {0} not found in LSA database", privilege));
+
+ default:
+ throw new Win32Exception(LsaNtStatusToWinError((int)ret));
+ }
+ }
+
+ static LSA_UNICODE_STRING InitLsaString(string s)
+ {
+ // Unicode strings max. 32KB
+ if (s.Length > 0x7ffe)
+ throw new ArgumentException("String too long");
+
+ LSA_UNICODE_STRING lus = new LSA_UNICODE_STRING();
+ lus.Buffer = s;
+ lus.Length = (ushort)(s.Length * sizeof(char));
+ lus.MaximumLength = (ushort)(lus.Length + sizeof(char));
+
+ return lus;
+ }
+
+ public void Dispose()
+ {
+ if (lsaHandle != IntPtr.Zero)
+ {
+ LsaClose(lsaHandle);
+ lsaHandle = IntPtr.Zero;
+ }
+ GC.SuppressFinalize(this);
+ }
+ ~LsaRightHelper() { Dispose(); }
+ }
+}
+"@
+
+$original_tmp = $env:TMP
+$env:TMP = $_remote_tmp
+Add-Type -TypeDefinition $sec_helper_util
+$env:TMP = $original_tmp
+
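+# Works out the membership delta for the requested action:
+#   add    -> added   = requested users not already on the right
+#   remove -> removed = requested users currently on the right
+#   set    -> both, so the right ends up exactly matching the request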
+Function Compare-UserList($existing_users, $new_users) {
+ $added_users = [String[]]@()
+ $removed_users = [String[]]@()
+ if ($action -eq "add") {
+ $added_users = [Linq.Enumerable]::Except($new_users, $existing_users)
+ } elseif ($action -eq "remove") {
+ $removed_users = [Linq.Enumerable]::Intersect($new_users, $existing_users)
+ } else {
+ $added_users = [Linq.Enumerable]::Except($new_users, $existing_users)
+ $removed_users = [Linq.Enumerable]::Except($existing_users, $new_users)
+ }
+
+ $change_result = @{
+ added = $added_users
+ removed = $removed_users
+ }
+
+ return $change_result
+}
+
+# C# class we can use to enumerate/add/remove rights
+$lsa_helper = New-Object -TypeName Ansible.LsaRightHelper
+
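+# Normalise each requested account name to its SID string so it can be
+# compared directly with the SIDs LSA returns for the right.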
+$new_users = [System.Collections.ArrayList]@()
+foreach ($user in $users) {
+ $user_sid = Convert-ToSID -account_name $user
+ $new_users.Add($user_sid) > $null
+}
+$new_users = [String[]]$new_users.ToArray()
+try {
+ $existing_users = $lsa_helper.EnumerateAccountsWithUserRight($name)
+} catch [ArgumentException] {
+ Fail-Json -obj $result -message "the specified right $name is not a valid right"
+} catch {
+ Fail-Json -obj $result -message "failed to enumerate existing accounts with right: $($_.Exception.Message)"
+}
+
+$change_result = Compare-UserList -existing_users $existing_users -new_users $new_users
+if (($change_result.added.Length -gt 0) -or ($change_result.removed.Length -gt 0)) {
+ $result.changed = $true
+ $diff_text = "[$name]`n"
+
+ # used in diff mode calculation
+ $new_user_list = [System.Collections.ArrayList]$existing_users
+ foreach ($user in $change_result.removed) {
+ if (-not $check_mode) {
+ $lsa_helper.RemovePrivilege($user, $name)
+ }
+ $user_name = Convert-FromSID -sid $user
+ $result.removed += $user_name
+ $diff_text += "-$user_name`n"
+ $new_user_list.Remove($user) > $null
+ }
+ foreach ($user in $change_result.added) {
+ if (-not $check_mode) {
+ $lsa_helper.AddPrivilege($user, $name)
+ }
+ $user_name = Convert-FromSID -sid $user
+ $result.added += $user_name
+ $diff_text += "+$user_name`n"
+ $new_user_list.Add($user) > $null
+ }
+
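+    # Prefix the whole section with + or - when the right goes from empty to
+    # populated or vice versa, mirroring ini-style diff output.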
+ if ($diff_mode) {
+ if ($new_user_list.Count -eq 0) {
+ $diff_text = "-$diff_text"
+ } else {
+ if ($existing_users.Count -eq 0) {
+ $diff_text = "+$diff_text"
+ }
+ }
+ $result.diff.prepared = $diff_text
+ }
+}
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_user_right.py b/test/support/windows-integration/plugins/modules/win_user_right.py
new file mode 100644
index 00000000..55882083
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_user_right.py
@@ -0,0 +1,108 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_user_right
+version_added: '2.4'
+short_description: Manage Windows User Rights
+description:
+- Add, remove or set User Rights for users or groups.
+- You can set user rights for both local and domain accounts.
+options:
+ name:
+ description:
+ - The name of the User Right as shown by the C(Constant Name) value from
+ U(https://technet.microsoft.com/en-us/library/dd349804.aspx).
+ - The module will return an error if the right is invalid.
+ type: str
+ required: yes
+ users:
+ description:
+ - A list of users or groups to add/remove on the User Right.
+  - These can be in the form DOMAIN\user-group or user-group@DOMAIN.COM for
+    domain users/groups.
+  - For local users/groups it can be in the form user-group, .\user-group, or
+    SERVERNAME\user-group, where SERVERNAME is the name of the remote server.
+ - You can also add special local accounts like SYSTEM and others.
+ - Can be set to an empty list with I(action=set) to remove all accounts
+ from the right.
+ type: list
+ required: yes
+ action:
+ description:
+ - C(add) will add the users/groups to the existing right.
+ - C(remove) will remove the users/groups from the existing right.
+ - C(set) will replace the users/groups of the existing right.
+ type: str
+ default: set
+ choices: [ add, remove, set ]
+notes:
+- If the server is domain joined, this module can change a right, but if a GPO
+  governs this right then the changes won't last.
+seealso:
+- module: win_group
+- module: win_group_membership
+- module: win_user
+author:
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+---
+- name: Replace the entries of Deny log on locally
+ win_user_right:
+ name: SeDenyInteractiveLogonRight
+ users:
+ - Guest
+ - Users
+ action: set
+
+- name: Add account to Log on as a service
+ win_user_right:
+ name: SeServiceLogonRight
+ users:
+ - .\Administrator
+ - '{{ansible_hostname}}\local-user'
+ action: add
+
+- name: Remove accounts who can create Symbolic links
+ win_user_right:
+ name: SeCreateSymbolicLinkPrivilege
+ users:
+ - SYSTEM
+ - Administrators
+ - DOMAIN\User
+ - group@DOMAIN.COM
+ action: remove
+
+- name: Remove all accounts who cannot log on remote interactively
+ win_user_right:
+ name: SeDenyRemoteInteractiveLogonRight
+ users: []
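+
+# Illustrative sketch, not part of the original examples: run in check mode
+# first to preview which accounts would be added or removed.
+- name: Preview changes to the batch logon right
+  win_user_right:
+    name: SeBatchLogonRight
+    users:
+    - DOMAIN\ServiceAccount
+    action: add
+  check_mode: yes
+  register: right_preview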
+'''
+
+RETURN = r'''
+added:
+  description: A list of accounts that were added to the right. This is empty
+    if no accounts were added.
+ returned: success
+ type: list
+ sample: ["NT AUTHORITY\\SYSTEM", "DOMAIN\\User"]
+removed:
+  description: A list of accounts that were removed from the right. This is
+    empty if no accounts were removed.
+ returned: success
+ type: list
+ sample: ["SERVERNAME\\Administrator", "BUILTIN\\Administrators"]
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_wait_for.ps1 b/test/support/windows-integration/plugins/modules/win_wait_for.ps1
new file mode 100644
index 00000000..e0a9a720
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_wait_for.ps1
@@ -0,0 +1,259 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.FileUtil
+
+$ErrorActionPreference = "Stop"
+
+$params = Parse-Args -arguments $args -supports_check_mode $true
+
+$connect_timeout = Get-AnsibleParam -obj $params -name "connect_timeout" -type "int" -default 5
+$delay = Get-AnsibleParam -obj $params -name "delay" -type "int"
+$exclude_hosts = Get-AnsibleParam -obj $params -name "exclude_hosts" -type "list"
+$hostname = Get-AnsibleParam -obj $params -name "host" -type "str" -default "127.0.0.1"
+$path = Get-AnsibleParam -obj $params -name "path" -type "path"
+$port = Get-AnsibleParam -obj $params -name "port" -type "int"
+$regex = Get-AnsibleParam -obj $params -name "regex" -type "str" -aliases "search_regex","regexp"
+$sleep = Get-AnsibleParam -obj $params -name "sleep" -type "int" -default 1
+$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "started" -validateset "present","started","stopped","absent","drained"
+$timeout = Get-AnsibleParam -obj $params -name "timeout" -type "int" -default 300
+
+$result = @{
+ changed = $false
+ elapsed = 0
+}
+
+# validate the input with the various options
+if ($null -ne $port -and $null -ne $path) {
+ Fail-Json $result "port and path parameter can not both be passed to win_wait_for"
+}
+if ($null -ne $exclude_hosts -and $state -ne "drained") {
+ Fail-Json $result "exclude_hosts should only be with state=drained"
+}
+if ($null -ne $path) {
+ if ($state -in @("stopped","drained")) {
+ Fail-Json $result "state=$state should only be used for checking a port in the win_wait_for module"
+ }
+
+ if ($null -ne $exclude_hosts) {
+ Fail-Json $result "exclude_hosts should only be used when checking a port and state=drained in the win_wait_for module"
+ }
+}
+
+if ($null -ne $port) {
+ if ($null -ne $regex) {
+ Fail-Json $result "regex should by used when checking a string in a file in the win_wait_for module"
+ }
+
+ if ($null -ne $exclude_hosts -and $state -ne "drained") {
+ Fail-Json $result "exclude_hosts should be used when state=drained in the win_wait_for module"
+ }
+}
+
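+# Attempts a TCP connection with a timeout; returns $true only if the async
+# connect completes successfully within $connect_timeout seconds.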
+Function Test-Port($hostname, $port) {
+ $timeout = $connect_timeout * 1000
+ $socket = New-Object -TypeName System.Net.Sockets.TcpClient
+ $connect = $socket.BeginConnect($hostname, $port, $null, $null)
+ $wait = $connect.AsyncWaitHandle.WaitOne($timeout, $false)
+
+ if ($wait) {
+ try {
+ $socket.EndConnect($connect) | Out-Null
+ $valid = $true
+ } catch {
+ $valid = $false
+ }
+ } else {
+ $valid = $false
+ }
+
+ $socket.Close()
+ $socket.Dispose()
+
+ $valid
+}
+
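+# Returns the remote addresses of all active TCP connections to the given
+# local port; a hostname of 0.0.0.0 matches the port on any local address.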
+Function Get-PortConnections($hostname, $port) {
+ $connections = @()
+
+ $conn_info = [Net.NetworkInformation.IPGlobalProperties]::GetIPGlobalProperties()
+ if ($hostname -eq "0.0.0.0") {
+ $active_connections = $conn_info.GetActiveTcpConnections() | Where-Object { $_.LocalEndPoint.Port -eq $port }
+ } else {
+ $active_connections = $conn_info.GetActiveTcpConnections() | Where-Object { $_.LocalEndPoint.Address -eq $hostname -and $_.LocalEndPoint.Port -eq $port }
+ }
+
+ if ($null -ne $active_connections) {
+ foreach ($active_connection in $active_connections) {
+ $connections += $active_connection.RemoteEndPoint.Address
+ }
+ }
+
+ $connections
+}
+
+$module_start = Get-Date
+
+if ($null -ne $delay) {
+ Start-Sleep -Seconds $delay
+}
+
+$attempts = 0
+if ($null -eq $path -and $null -eq $port -and $state -ne "drained") {
+ Start-Sleep -Seconds $timeout
+} elseif ($null -ne $path) {
+ if ($state -in @("present", "started")) {
+ # check if the file exists or string exists in file
+ $start_time = Get-Date
+ $complete = $false
+ while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) {
+ $attempts += 1
+ if (Test-AnsiblePath -Path $path) {
+ if ($null -eq $regex) {
+ $complete = $true
+ break
+ } else {
+ $file_contents = Get-Content -Path $path -Raw
+ if ($file_contents -match $regex) {
+ $complete = $true
+ break
+ }
+ }
+ }
+ Start-Sleep -Seconds $sleep
+ }
+
+ if ($complete -eq $false) {
+ $result.elapsed = ((Get-Date) - $module_start).TotalSeconds
+ $result.wait_attempts = $attempts
+ if ($null -eq $regex) {
+ Fail-Json $result "timeout while waiting for file $path to be present"
+ } else {
+ Fail-Json $result "timeout while waiting for string regex $regex in file $path to match"
+ }
+ }
+ } elseif ($state -in @("absent")) {
+ # check if the file is deleted or string doesn't exist in file
+ $start_time = Get-Date
+ $complete = $false
+ while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) {
+ $attempts += 1
+ if (Test-AnsiblePath -Path $path) {
+ if ($null -ne $regex) {
+ $file_contents = Get-Content -Path $path -Raw
+ if ($file_contents -notmatch $regex) {
+ $complete = $true
+ break
+ }
+ }
+ } else {
+ $complete = $true
+ break
+ }
+
+ Start-Sleep -Seconds $sleep
+ }
+
+ if ($complete -eq $false) {
+ $result.elapsed = ((Get-Date) - $module_start).TotalSeconds
+ $result.wait_attempts = $attempts
+ if ($null -eq $regex) {
+ Fail-Json $result "timeout while waiting for file $path to be absent"
+ } else {
+ Fail-Json $result "timeout while waiting for string regex $regex in file $path to not match"
+ }
+ }
+ }
+} elseif ($null -ne $port) {
+ if ($state -in @("started","present")) {
+ # check that the port is online and is listening
+ $start_time = Get-Date
+ $complete = $false
+ while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) {
+ $attempts += 1
+ $port_result = Test-Port -hostname $hostname -port $port
+ if ($port_result -eq $true) {
+ $complete = $true
+ break
+ }
+
+ Start-Sleep -Seconds $sleep
+ }
+
+ if ($complete -eq $false) {
+ $result.elapsed = ((Get-Date) - $module_start).TotalSeconds
+ $result.wait_attempts = $attempts
+ Fail-Json $result "timeout while waiting for $($hostname):$port to start listening"
+ }
+ } elseif ($state -in @("stopped","absent")) {
+ # check that the port is offline and is not listening
+ $start_time = Get-Date
+ $complete = $false
+ while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) {
+ $attempts += 1
+ $port_result = Test-Port -hostname $hostname -port $port
+ if ($port_result -eq $false) {
+ $complete = $true
+ break
+ }
+
+ Start-Sleep -Seconds $sleep
+ }
+
+ if ($complete -eq $false) {
+ $result.elapsed = ((Get-Date) - $module_start).TotalSeconds
+ $result.wait_attempts = $attempts
+ Fail-Json $result "timeout while waiting for $($hostname):$port to stop listening"
+ }
+ } elseif ($state -eq "drained") {
+ # check that the local port is online but has no active connections
+ $start_time = Get-Date
+ $complete = $false
+ while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) {
+ $attempts += 1
+ $active_connections = Get-PortConnections -hostname $hostname -port $port
+ if ($null -eq $active_connections) {
+ $complete = $true
+ break
+ } elseif ($active_connections.Count -eq 0) {
+ # no connections on port
+ $complete = $true
+ break
+ } else {
+                # there are active connections, check if we should ignore any hosts
+ if ($null -ne $exclude_hosts) {
+ $connection_info = $active_connections
+ foreach ($exclude_host in $exclude_hosts) {
+ try {
+ $exclude_ips = [System.Net.Dns]::GetHostAddresses($exclude_host) | ForEach-Object { Write-Output $_.IPAddressToString }
+ $connection_info = $connection_info | Where-Object { $_ -notin $exclude_ips }
+ } catch { # ignore invalid hostnames
+ Add-Warning -obj $result -message "Invalid hostname specified $exclude_host"
+ }
+ }
+
+ if ($connection_info.Count -eq 0) {
+ $complete = $true
+ break
+ }
+ }
+ }
+
+ Start-Sleep -Seconds $sleep
+ }
+
+ if ($complete -eq $false) {
+ $result.elapsed = ((Get-Date) - $module_start).TotalSeconds
+ $result.wait_attempts = $attempts
+ Fail-Json $result "timeout while waiting for $($hostname):$port to drain"
+ }
+ }
+}
+
+$result.elapsed = ((Get-Date) - $module_start).TotalSeconds
+$result.wait_attempts = $attempts
+
+Exit-Json $result
diff --git a/test/support/windows-integration/plugins/modules/win_wait_for.py b/test/support/windows-integration/plugins/modules/win_wait_for.py
new file mode 100644
index 00000000..85721e7d
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_wait_for.py
@@ -0,0 +1,155 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub, actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_wait_for
+version_added: '2.4'
+short_description: Waits for a condition before continuing
+description:
+- You can wait for a set amount of time C(timeout); this is the default if
+  nothing is specified.
+- Waiting for a port to become available is useful when services are not
+  immediately available after their init scripts return, which is true of
+  certain Java application servers.
+- You can wait for a file to exist or not exist on the filesystem.
+- This module can also be used to wait for a regex match string to be present
+ in a file.
+- You can wait for active connections on a local port to be closed before
+  continuing.
+options:
+ connect_timeout:
+ description:
+ - The maximum number of seconds to wait for a connection to happen before
+ closing and retrying.
+ type: int
+ default: 5
+ delay:
+ description:
+ - The number of seconds to wait before starting to poll.
+ type: int
+ exclude_hosts:
+ description:
+ - The list of hosts or IPs to ignore when looking for active TCP
+ connections when C(state=drained).
+ type: list
+ host:
+ description:
+ - A resolvable hostname or IP address to wait for.
+  - If C(state=drained) then it will only check for connections on the IP
+    specified. You can use '0.0.0.0' to check all host IPs.
+ type: str
+ default: '127.0.0.1'
+ path:
+ description:
+ - The path to a file on the filesystem to check.
+ - If C(state) is present or started then it will wait until the file
+ exists.
+ - If C(state) is absent then it will wait until the file does not exist.
+ type: path
+ port:
+ description:
+ - The port number to poll on C(host).
+ type: int
+ regex:
+ description:
+ - Can be used to match a string in a file.
+ - If C(state) is present or started then it will wait until the regex
+ matches.
+ - If C(state) is absent then it will wait until the regex does not match.
+ - Defaults to a multiline regex.
+ type: str
+ aliases: [ "search_regex", "regexp" ]
+ sleep:
+ description:
+ - Number of seconds to sleep between checks.
+ type: int
+ default: 1
+ state:
+ description:
+  - When checking a port, C(started) will ensure the port is open, C(stopped)
+    will check that it is closed and C(drained) will check for active
+    connections.
+ - When checking for a file or a search string C(present) or C(started) will
+ ensure that the file or string is present, C(absent) will check that the
+ file or search string is absent or removed.
+ type: str
+ choices: [ absent, drained, present, started, stopped ]
+ default: started
+ timeout:
+ description:
+ - The maximum number of seconds to wait for.
+ type: int
+ default: 300
+seealso:
+- module: wait_for
+- module: win_wait_for_process
+author:
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+- name: Wait 300 seconds for port 8000 to become open on the host, don't start checking for 10 seconds
+ win_wait_for:
+ port: 8000
+ delay: 10
+
+- name: Wait 150 seconds for port 8000 of any IP to close active connections
+ win_wait_for:
+ host: 0.0.0.0
+ port: 8000
+ state: drained
+ timeout: 150
+
+- name: Wait for port 8000 of any IP to close active connection, ignoring certain hosts
+ win_wait_for:
+ host: 0.0.0.0
+ port: 8000
+ state: drained
+ exclude_hosts: ['10.2.1.2', '10.2.1.3']
+
+- name: Wait for file C:\temp\log.txt to exist before continuing
+ win_wait_for:
+ path: C:\temp\log.txt
+
+- name: Wait until process complete is in the file before continuing
+ win_wait_for:
+ path: C:\temp\log.txt
+ regex: process complete
+
+- name: Wait until file is removed
+ win_wait_for:
+ path: C:\temp\log.txt
+ state: absent
+
+- name: Wait until port 1234 is offline but try every 10 seconds
+ win_wait_for:
+ port: 1234
+ state: absent
+ sleep: 10
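+
+# Illustrative sketch, not part of the original examples: register the result
+# to inspect how long the wait took and how many polls were made.
+- name: Wait for port 8000 and record the timing
+  win_wait_for:
+    port: 8000
+  register: wait_result
+
+- name: Show elapsed seconds and poll attempts
+  debug:
+    msg: "{{ wait_result.elapsed }}s over {{ wait_result.wait_attempts }} attempts"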
+'''
+
+RETURN = r'''
+wait_attempts:
+  description: The number of attempts to poll the file or port before the
+    module finishes.
+ returned: always
+ type: int
+ sample: 1
+elapsed:
+  description: The elapsed seconds between the start of the poll and the end
+    of the module. This includes the delay if the option is set.
+ returned: always
+ type: float
+ sample: 2.1406487
+'''
diff --git a/test/support/windows-integration/plugins/modules/win_whoami.ps1 b/test/support/windows-integration/plugins/modules/win_whoami.ps1
new file mode 100644
index 00000000..6c9965af
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_whoami.ps1
@@ -0,0 +1,837 @@
+#!powershell
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+#Requires -Module Ansible.ModuleUtils.Legacy
+#Requires -Module Ansible.ModuleUtils.CamelConversion
+
+$ErrorActionPreference = "Stop"
+
+$params = Parse-Args $args -supports_check_mode $true
+$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP
+
+$session_util = @'
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security.Principal;
+using System.Text;
+
+namespace Ansible
+{
+ public class SessionInfo
+ {
+ // SECURITY_LOGON_SESSION_DATA
+ public UInt64 LogonId { get; internal set; }
+ public Sid Account { get; internal set; }
+ public string LoginDomain { get; internal set; }
+ public string AuthenticationPackage { get; internal set; }
+ public SECURITY_LOGON_TYPE LogonType { get; internal set; }
+ public string LoginTime { get; internal set; }
+ public string LogonServer { get; internal set; }
+ public string DnsDomainName { get; internal set; }
+ public string Upn { get; internal set; }
+ public ArrayList UserFlags { get; internal set; }
+
+ // TOKEN_STATISTICS
+ public SECURITY_IMPERSONATION_LEVEL ImpersonationLevel { get; internal set; }
+ public TOKEN_TYPE TokenType { get; internal set; }
+
+ // TOKEN_GROUPS
+ public ArrayList Groups { get; internal set; }
+ public ArrayList Rights { get; internal set; }
+
+ // TOKEN_MANDATORY_LABEL
+ public Sid Label { get; internal set; }
+
+ // TOKEN_PRIVILEGES
+ public Hashtable Privileges { get; internal set; }
+ }
+
+ public class Win32Exception : System.ComponentModel.Win32Exception
+ {
+ private string _msg;
+ public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { }
+ public Win32Exception(int errorCode, string message) : base(errorCode)
+ {
+ _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode);
+ }
+ public override string Message { get { return _msg; } }
+ public static explicit operator Win32Exception(string message) { return new Win32Exception(message); }
+ }
+
+ [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+ public struct LSA_UNICODE_STRING
+ {
+ public UInt16 Length;
+ public UInt16 MaximumLength;
+ public IntPtr buffer;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct LUID
+ {
+ public UInt32 LowPart;
+ public Int32 HighPart;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SECURITY_LOGON_SESSION_DATA
+ {
+ public UInt32 Size;
+ public LUID LogonId;
+ public LSA_UNICODE_STRING Username;
+ public LSA_UNICODE_STRING LoginDomain;
+ public LSA_UNICODE_STRING AuthenticationPackage;
+ public SECURITY_LOGON_TYPE LogonType;
+ public UInt32 Session;
+ public IntPtr Sid;
+ public UInt64 LoginTime;
+ public LSA_UNICODE_STRING LogonServer;
+ public LSA_UNICODE_STRING DnsDomainName;
+ public LSA_UNICODE_STRING Upn;
+ public UInt32 UserFlags;
+ public LSA_LAST_INTER_LOGON_INFO LastLogonInfo;
+ public LSA_UNICODE_STRING LogonScript;
+ public LSA_UNICODE_STRING ProfilePath;
+ public LSA_UNICODE_STRING HomeDirectory;
+ public LSA_UNICODE_STRING HomeDirectoryDrive;
+ public UInt64 LogoffTime;
+ public UInt64 KickOffTime;
+ public UInt64 PasswordLastSet;
+ public UInt64 PasswordCanChange;
+ public UInt64 PasswordMustChange;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct LSA_LAST_INTER_LOGON_INFO
+ {
+ public UInt64 LastSuccessfulLogon;
+ public UInt64 LastFailedLogon;
+ public UInt32 FailedAttemptCountSinceLastSuccessfulLogon;
+ }
+
+ public enum TOKEN_TYPE
+ {
+ TokenPrimary = 1,
+ TokenImpersonation
+ }
+
+ public enum SECURITY_IMPERSONATION_LEVEL
+ {
+ SecurityAnonymous,
+ SecurityIdentification,
+ SecurityImpersonation,
+ SecurityDelegation
+ }
+
+ public enum SECURITY_LOGON_TYPE
+ {
+        System = 0, // Used only by the System account
+ Interactive = 2,
+ Network,
+ Batch,
+ Service,
+ Proxy,
+ Unlock,
+ NetworkCleartext,
+ NewCredentials,
+ RemoteInteractive,
+ CachedInteractive,
+ CachedRemoteInteractive,
+ CachedUnlock
+ }
+
+ [Flags]
+ public enum TokenGroupAttributes : uint
+ {
+ SE_GROUP_ENABLED = 0x00000004,
+ SE_GROUP_ENABLED_BY_DEFAULT = 0x00000002,
+ SE_GROUP_INTEGRITY = 0x00000020,
+ SE_GROUP_INTEGRITY_ENABLED = 0x00000040,
+ SE_GROUP_LOGON_ID = 0xC0000000,
+ SE_GROUP_MANDATORY = 0x00000001,
+ SE_GROUP_OWNER = 0x00000008,
+ SE_GROUP_RESOURCE = 0x20000000,
+ SE_GROUP_USE_FOR_DENY_ONLY = 0x00000010,
+ }
+
+ [Flags]
+ public enum UserFlags : uint
+ {
+ LOGON_OPTIMIZED = 0x4000,
+ LOGON_WINLOGON = 0x8000,
+ LOGON_PKINIT = 0x10000,
+ LOGON_NOT_OPTMIZED = 0x20000,
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SID_AND_ATTRIBUTES
+ {
+ public IntPtr Sid;
+ public UInt32 Attributes;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct LUID_AND_ATTRIBUTES
+ {
+ public LUID Luid;
+ public UInt32 Attributes;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct TOKEN_GROUPS
+ {
+ public UInt32 GroupCount;
+ [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)]
+ public SID_AND_ATTRIBUTES[] Groups;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct TOKEN_MANDATORY_LABEL
+ {
+ public SID_AND_ATTRIBUTES Label;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct TOKEN_STATISTICS
+ {
+ public LUID TokenId;
+ public LUID AuthenticationId;
+ public UInt64 ExpirationTime;
+ public TOKEN_TYPE TokenType;
+ public SECURITY_IMPERSONATION_LEVEL ImpersonationLevel;
+ public UInt32 DynamicCharged;
+ public UInt32 DynamicAvailable;
+ public UInt32 GroupCount;
+ public UInt32 PrivilegeCount;
+ public LUID ModifiedId;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct TOKEN_PRIVILEGES
+ {
+ public UInt32 PrivilegeCount;
+ [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)]
+ public LUID_AND_ATTRIBUTES[] Privileges;
+ }
+
+ public class AccessToken : IDisposable
+ {
+ public enum TOKEN_INFORMATION_CLASS
+ {
+ TokenUser = 1,
+ TokenGroups,
+ TokenPrivileges,
+ TokenOwner,
+ TokenPrimaryGroup,
+ TokenDefaultDacl,
+ TokenSource,
+ TokenType,
+ TokenImpersonationLevel,
+ TokenStatistics,
+ TokenRestrictedSids,
+ TokenSessionId,
+ TokenGroupsAndPrivileges,
+ TokenSessionReference,
+ TokenSandBoxInert,
+ TokenAuditPolicy,
+ TokenOrigin,
+ TokenElevationType,
+ TokenLinkedToken,
+ TokenElevation,
+ TokenHasRestrictions,
+ TokenAccessInformation,
+ TokenVirtualizationAllowed,
+ TokenVirtualizationEnabled,
+ TokenIntegrityLevel,
+ TokenUIAccess,
+ TokenMandatoryPolicy,
+ TokenLogonSid,
+ TokenIsAppContainer,
+ TokenCapabilities,
+ TokenAppContainerSid,
+ TokenAppContainerNumber,
+ TokenUserClaimAttributes,
+ TokenDeviceClaimAttributes,
+ TokenRestrictedUserClaimAttributes,
+ TokenRestrictedDeviceClaimAttributes,
+ TokenDeviceGroups,
+ TokenRestrictedDeviceGroups,
+ TokenSecurityAttributes,
+ TokenIsRestricted,
+ MaxTokenInfoClass
+ }
+
+ public IntPtr hToken = IntPtr.Zero;
+
+ [DllImport("kernel32.dll")]
+ private static extern IntPtr GetCurrentProcess();
+
+ [DllImport("advapi32.dll", SetLastError = true)]
+ private static extern bool OpenProcessToken(
+ IntPtr ProcessHandle,
+ TokenAccessLevels DesiredAccess,
+ out IntPtr TokenHandle);
+
+ [DllImport("advapi32.dll", SetLastError = true)]
+ private static extern bool GetTokenInformation(
+ IntPtr TokenHandle,
+ TOKEN_INFORMATION_CLASS TokenInformationClass,
+ IntPtr TokenInformation,
+ UInt32 TokenInformationLength,
+ out UInt32 ReturnLength);
+
+ public AccessToken(TokenAccessLevels tokenAccessLevels)
+ {
+ IntPtr currentProcess = GetCurrentProcess();
+ if (!OpenProcessToken(currentProcess, tokenAccessLevels, out hToken))
+ throw new Win32Exception("OpenProcessToken() for current process failed");
+ }
+
+ public IntPtr GetTokenInformation<T>(out T tokenInformation, TOKEN_INFORMATION_CLASS tokenClass)
+ {
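+            // Standard two-call pattern: the first call with a null buffer reports
+            // the required length, the second call fills the allocated buffer.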
+ UInt32 tokenLength = 0;
+ GetTokenInformation(hToken, tokenClass, IntPtr.Zero, 0, out tokenLength);
+
+ IntPtr infoPtr = Marshal.AllocHGlobal((int)tokenLength);
+
+ if (!GetTokenInformation(hToken, tokenClass, infoPtr, tokenLength, out tokenLength))
+ throw new Win32Exception(String.Format("GetTokenInformation() data for {0} failed", tokenClass.ToString()));
+
+ tokenInformation = (T)Marshal.PtrToStructure(infoPtr, typeof(T));
+ return infoPtr;
+ }
+
+ public void Dispose()
+ {
+ GC.SuppressFinalize(this);
+ }
+
+ ~AccessToken() { Dispose(); }
+ }
+
+ public class LsaHandle : IDisposable
+ {
+ [Flags]
+ public enum DesiredAccess : uint
+ {
+ POLICY_VIEW_LOCAL_INFORMATION = 0x00000001,
+ POLICY_VIEW_AUDIT_INFORMATION = 0x00000002,
+ POLICY_GET_PRIVATE_INFORMATION = 0x00000004,
+ POLICY_TRUST_ADMIN = 0x00000008,
+ POLICY_CREATE_ACCOUNT = 0x00000010,
+ POLICY_CREATE_SECRET = 0x00000020,
+ POLICY_CREATE_PRIVILEGE = 0x00000040,
+ POLICY_SET_DEFAULT_QUOTA_LIMITS = 0x00000080,
+ POLICY_SET_AUDIT_REQUIREMENTS = 0x00000100,
+ POLICY_AUDIT_LOG_ADMIN = 0x00000200,
+ POLICY_SERVER_ADMIN = 0x00000400,
+ POLICY_LOOKUP_NAMES = 0x00000800,
+ POLICY_NOTIFICATION = 0x00001000
+ }
+
+ public IntPtr handle = IntPtr.Zero;
+
+ [DllImport("advapi32.dll", SetLastError = true)]
+ private static extern uint LsaOpenPolicy(
+ LSA_UNICODE_STRING[] SystemName,
+ ref LSA_OBJECT_ATTRIBUTES ObjectAttributes,
+ DesiredAccess AccessMask,
+ out IntPtr PolicyHandle);
+
+ [DllImport("advapi32.dll", SetLastError = true)]
+ private static extern uint LsaClose(
+ IntPtr ObjectHandle);
+
+ [DllImport("advapi32.dll", SetLastError = false)]
+ private static extern int LsaNtStatusToWinError(
+ uint Status);
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct LSA_OBJECT_ATTRIBUTES
+ {
+ public int Length;
+ public IntPtr RootDirectory;
+ public IntPtr ObjectName;
+ public int Attributes;
+ public IntPtr SecurityDescriptor;
+ public IntPtr SecurityQualityOfService;
+ }
+
+ public LsaHandle(DesiredAccess desiredAccess)
+ {
+ LSA_OBJECT_ATTRIBUTES lsaAttr;
+ lsaAttr.RootDirectory = IntPtr.Zero;
+ lsaAttr.ObjectName = IntPtr.Zero;
+ lsaAttr.Attributes = 0;
+ lsaAttr.SecurityDescriptor = IntPtr.Zero;
+ lsaAttr.SecurityQualityOfService = IntPtr.Zero;
+ lsaAttr.Length = Marshal.SizeOf(typeof(LSA_OBJECT_ATTRIBUTES));
+ LSA_UNICODE_STRING[] system = new LSA_UNICODE_STRING[1];
+ system[0].buffer = IntPtr.Zero;
+
+ uint res = LsaOpenPolicy(system, ref lsaAttr, desiredAccess, out handle);
+ if (res != 0)
+ throw new Win32Exception(LsaNtStatusToWinError(res), "LsaOpenPolicy() failed");
+ }
+
+ public void Dispose()
+ {
+ if (handle != IntPtr.Zero)
+ {
+ LsaClose(handle);
+ handle = IntPtr.Zero;
+ }
+ GC.SuppressFinalize(this);
+ }
+
+ ~LsaHandle() { Dispose(); }
+ }
+
+ public class Sid
+ {
+ public string SidString { get; internal set; }
+ public string DomainName { get; internal set; }
+ public string AccountName { get; internal set; }
+ public SID_NAME_USE SidType { get; internal set; }
+
+ public enum SID_NAME_USE
+ {
+ SidTypeUser = 1,
+ SidTypeGroup,
+ SidTypeDomain,
+ SidTypeAlias,
+ SidTypeWellKnownGroup,
+ SidTypeDeletedAccount,
+ SidTypeInvalid,
+ SidTypeUnknown,
+ SidTypeComputer,
+ SidTypeLabel,
+ SidTypeLogon,
+ }
+
+ [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)]
+ private static extern bool LookupAccountSid(
+ string lpSystemName,
+ [MarshalAs(UnmanagedType.LPArray)]
+ byte[] Sid,
+ StringBuilder lpName,
+ ref UInt32 cchName,
+ StringBuilder ReferencedDomainName,
+ ref UInt32 cchReferencedDomainName,
+ out SID_NAME_USE peUse);
+
+ public Sid(IntPtr sidPtr)
+ {
+ SecurityIdentifier sid;
+ try
+ {
+ sid = new SecurityIdentifier(sidPtr);
+ }
+ catch (Exception e)
+ {
+ throw new ArgumentException(String.Format("Failed to cast IntPtr to SecurityIdentifier: {0}", e));
+ }
+
+ SetSidInfo(sid);
+ }
+
+ public Sid(SecurityIdentifier sid)
+ {
+ SetSidInfo(sid);
+ }
+
+ public override string ToString()
+ {
+ return SidString;
+ }
+
+ private void SetSidInfo(SecurityIdentifier sid)
+ {
+ byte[] sidBytes = new byte[sid.BinaryLength];
+ sid.GetBinaryForm(sidBytes, 0);
+
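+            // Two-call pattern: the first LookupAccountSid reports the required
+            // buffer sizes, the StringBuilders are grown, and the real lookup
+            // runs below.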
+ StringBuilder lpName = new StringBuilder();
+ UInt32 cchName = 0;
+ StringBuilder referencedDomainName = new StringBuilder();
+ UInt32 cchReferencedDomainName = 0;
+ SID_NAME_USE peUse;
+ LookupAccountSid(null, sidBytes, lpName, ref cchName, referencedDomainName, ref cchReferencedDomainName, out peUse);
+
+ lpName.EnsureCapacity((int)cchName);
+ referencedDomainName.EnsureCapacity((int)cchReferencedDomainName);
+
+ SidString = sid.ToString();
+ if (!LookupAccountSid(null, sidBytes, lpName, ref cchName, referencedDomainName, ref cchReferencedDomainName, out peUse))
+ {
+ int lastError = Marshal.GetLastWin32Error();
+
+ if (lastError != 1332 && lastError != 1789) // Fails to lookup Logon Sid
+ {
+ throw new Win32Exception(lastError, String.Format("LookupAccountSid() failed for SID: {0} {1}", sid.ToString(), lastError));
+ }
+ else if (SidString.StartsWith("S-1-5-5-"))
+ {
+ AccountName = String.Format("LogonSessionId_{0}", SidString.Substring(8));
+ DomainName = "NT AUTHORITY";
+ SidType = SID_NAME_USE.SidTypeLogon;
+ }
+ else
+ {
+ AccountName = null;
+ DomainName = null;
+ SidType = SID_NAME_USE.SidTypeUnknown;
+ }
+ }
+ else
+ {
+ AccountName = lpName.ToString();
+ DomainName = referencedDomainName.ToString();
+ SidType = peUse;
+ }
+ }
+ }
+
+ public class SessionUtil
+ {
+ [DllImport("secur32.dll", SetLastError = false)]
+ private static extern uint LsaFreeReturnBuffer(
+ IntPtr Buffer);
+
+ [DllImport("secur32.dll", SetLastError = false)]
+ private static extern uint LsaEnumerateLogonSessions(
+ out UInt64 LogonSessionCount,
+ out IntPtr LogonSessionList);
+
+ [DllImport("secur32.dll", SetLastError = false)]
+ private static extern uint LsaGetLogonSessionData(
+ IntPtr LogonId,
+ out IntPtr ppLogonSessionData);
+
+ [DllImport("advapi32.dll", SetLastError = false)]
+ private static extern int LsaNtStatusToWinError(
+ uint Status);
+
+ [DllImport("advapi32", SetLastError = true)]
+ private static extern uint LsaEnumerateAccountRights(
+ IntPtr PolicyHandle,
+ IntPtr AccountSid,
+ out IntPtr UserRights,
+ out UInt64 CountOfRights);
+
+ [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)]
+ private static extern bool LookupPrivilegeName(
+ string lpSystemName,
+ ref LUID lpLuid,
+ StringBuilder lpName,
+ ref UInt32 cchName);
+
+ private const UInt32 SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001;
+ private const UInt32 SE_PRIVILEGE_ENABLED = 0x00000002;
+ private const UInt32 STATUS_OBJECT_NAME_NOT_FOUND = 0xC0000034;
+ private const UInt32 STATUS_ACCESS_DENIED = 0xC0000022;
+
+ public static SessionInfo GetSessionInfo()
+ {
+ AccessToken accessToken = new AccessToken(TokenAccessLevels.Query);
+
+ // Get Privileges
+ Hashtable privilegeInfo = new Hashtable();
+ TOKEN_PRIVILEGES privileges;
+ IntPtr privilegesPtr = accessToken.GetTokenInformation(out privileges, AccessToken.TOKEN_INFORMATION_CLASS.TokenPrivileges);
+ LUID_AND_ATTRIBUTES[] luidAndAttributes = new LUID_AND_ATTRIBUTES[privileges.PrivilegeCount];
+ try
+ {
+ PtrToStructureArray(luidAndAttributes, privilegesPtr.ToInt64() + Marshal.SizeOf(privileges.PrivilegeCount));
+ }
+ finally
+ {
+ Marshal.FreeHGlobal(privilegesPtr);
+ }
+ foreach (LUID_AND_ATTRIBUTES luidAndAttribute in luidAndAttributes)
+ {
+ LUID privLuid = luidAndAttribute.Luid;
+ UInt32 privNameLen = 0;
+ StringBuilder privName = new StringBuilder();
+ LookupPrivilegeName(null, ref privLuid, null, ref privNameLen);
+ privName.EnsureCapacity((int)(privNameLen + 1));
+ if (!LookupPrivilegeName(null, ref privLuid, privName, ref privNameLen))
+ throw new Win32Exception("LookupPrivilegeName() failed");
+
+ string state = "disabled";
+ if ((luidAndAttribute.Attributes & SE_PRIVILEGE_ENABLED) == SE_PRIVILEGE_ENABLED)
+ state = "enabled";
+ if ((luidAndAttribute.Attributes & SE_PRIVILEGE_ENABLED_BY_DEFAULT) == SE_PRIVILEGE_ENABLED_BY_DEFAULT)
+ state = "enabled-by-default";
+ privilegeInfo.Add(privName.ToString(), state);
+ }
+
+ // Get Current Process LogonSID, User Rights and Groups
+ ArrayList userRights = new ArrayList();
+ ArrayList userGroups = new ArrayList();
+ TOKEN_GROUPS groups;
+ IntPtr groupsPtr = accessToken.GetTokenInformation(out groups, AccessToken.TOKEN_INFORMATION_CLASS.TokenGroups);
+ SID_AND_ATTRIBUTES[] sidAndAttributes = new SID_AND_ATTRIBUTES[groups.GroupCount];
+ LsaHandle lsaHandle = null;
+ // We can only get rights if we are an admin
+ if (new WindowsPrincipal(WindowsIdentity.GetCurrent()).IsInRole(WindowsBuiltInRole.Administrator))
+ lsaHandle = new LsaHandle(LsaHandle.DesiredAccess.POLICY_LOOKUP_NAMES);
+ try
+ {
+ PtrToStructureArray(sidAndAttributes, groupsPtr.ToInt64() + IntPtr.Size);
+ foreach (SID_AND_ATTRIBUTES sidAndAttribute in sidAndAttributes)
+ {
+ TokenGroupAttributes attributes = (TokenGroupAttributes)sidAndAttribute.Attributes;
+ if (attributes.HasFlag(TokenGroupAttributes.SE_GROUP_ENABLED) && lsaHandle != null)
+ {
+ ArrayList rights = GetAccountRights(lsaHandle.handle, sidAndAttribute.Sid);
+ foreach (string right in rights)
+ {
+ // Includes both Privileges and Account Rights, only add the ones with Logon in the name
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/bb545671(v=vs.85).aspx
+ if (!userRights.Contains(right) && right.Contains("Logon"))
+ userRights.Add(right);
+ }
+ }
+ // Do not include the Logon SID in the groups category
+ if (!attributes.HasFlag(TokenGroupAttributes.SE_GROUP_LOGON_ID))
+ {
+ Hashtable groupInfo = new Hashtable();
+ Sid group = new Sid(sidAndAttribute.Sid);
+ ArrayList groupAttributes = new ArrayList();
+ foreach (TokenGroupAttributes attribute in Enum.GetValues(typeof(TokenGroupAttributes)))
+ {
+ if (attributes.HasFlag(attribute))
+ {
+ string attributeName = attribute.ToString().Substring(9);
+ attributeName = attributeName.Replace('_', ' ');
+ attributeName = attributeName.First().ToString().ToUpper() + attributeName.Substring(1).ToLower();
+ groupAttributes.Add(attributeName);
+ }
+ }
+ // Using snake_case here as I can't generically convert all dict keys in PS (see Privileges)
+ groupInfo.Add("sid", group.SidString);
+ groupInfo.Add("domain_name", group.DomainName);
+ groupInfo.Add("account_name", group.AccountName);
+ groupInfo.Add("type", group.SidType);
+ groupInfo.Add("attributes", groupAttributes);
+ userGroups.Add(groupInfo);
+ }
+ }
+ }
+ finally
+ {
+ Marshal.FreeHGlobal(groupsPtr);
+ if (lsaHandle != null)
+ lsaHandle.Dispose();
+ }
+
+ // Get Integrity Level
+ Sid integritySid = null;
+ TOKEN_MANDATORY_LABEL mandatoryLabel;
+ IntPtr mandatoryLabelPtr = accessToken.GetTokenInformation(out mandatoryLabel, AccessToken.TOKEN_INFORMATION_CLASS.TokenIntegrityLevel);
+ Marshal.FreeHGlobal(mandatoryLabelPtr);
+ integritySid = new Sid(mandatoryLabel.Label.Sid);
+
+ // Get Token Statistics
+ TOKEN_STATISTICS tokenStats;
+ IntPtr tokenStatsPtr = accessToken.GetTokenInformation(out tokenStats, AccessToken.TOKEN_INFORMATION_CLASS.TokenStatistics);
+ Marshal.FreeHGlobal(tokenStatsPtr);
+
+ SessionInfo sessionInfo = GetSessionDataForLogonSession(tokenStats.AuthenticationId);
+ sessionInfo.Groups = userGroups;
+ sessionInfo.Label = integritySid;
+ sessionInfo.ImpersonationLevel = tokenStats.ImpersonationLevel;
+ sessionInfo.TokenType = tokenStats.TokenType;
+ sessionInfo.Privileges = privilegeInfo;
+ sessionInfo.Rights = userRights;
+ return sessionInfo;
+ }
+
+ private static ArrayList GetAccountRights(IntPtr lsaHandle, IntPtr sid)
+ {
+ UInt32 res;
+ ArrayList rights = new ArrayList();
+ IntPtr userRightsPointer = IntPtr.Zero;
+ UInt64 countOfRights = 0;
+
+ res = LsaEnumerateAccountRights(lsaHandle, sid, out userRightsPointer, out countOfRights);
+ if (res != 0 && res != STATUS_OBJECT_NAME_NOT_FOUND)
+ throw new Win32Exception(LsaNtStatusToWinError(res), "LsaEnumerateAccountRights() failed");
+ else if (res != STATUS_OBJECT_NAME_NOT_FOUND)
+ {
+ LSA_UNICODE_STRING[] userRights = new LSA_UNICODE_STRING[countOfRights];
+ PtrToStructureArray(userRights, userRightsPointer.ToInt64());
+ rights = new ArrayList();
+ foreach (LSA_UNICODE_STRING right in userRights)
+ rights.Add(Marshal.PtrToStringUni(right.buffer));
+ }
+
+ return rights;
+ }
+
+ private static SessionInfo GetSessionDataForLogonSession(LUID logonSession)
+ {
+ uint res;
+ UInt64 count = 0;
+ IntPtr luidPtr = IntPtr.Zero;
+ SessionInfo sessionInfo = null;
+ UInt64 processDataId = ConvertLuidToUint(logonSession);
+
+ res = LsaEnumerateLogonSessions(out count, out luidPtr);
+ if (res != 0)
+ throw new Win32Exception(LsaNtStatusToWinError(res), "LsaEnumerateLogonSessions() failed");
+ Int64 luidAddr = luidPtr.ToInt64();
+
+ try
+ {
+ for (UInt64 i = 0; i < count; i++)
+ {
+ IntPtr dataPointer = IntPtr.Zero;
+ res = LsaGetLogonSessionData(luidPtr, out dataPointer);
+                    if (res == STATUS_ACCESS_DENIED) // Non-admins won't be able to get info for sessions that are not their own
+ {
+ luidPtr = new IntPtr(luidPtr.ToInt64() + Marshal.SizeOf(typeof(LUID)));
+ continue;
+ }
+ else if (res != 0)
+ throw new Win32Exception(LsaNtStatusToWinError(res), String.Format("LsaGetLogonSessionData() failed {0}", res));
+
+ SECURITY_LOGON_SESSION_DATA sessionData = (SECURITY_LOGON_SESSION_DATA)Marshal.PtrToStructure(dataPointer, typeof(SECURITY_LOGON_SESSION_DATA));
+ UInt64 sessionDataid = ConvertLuidToUint(sessionData.LogonId);
+
+ if (sessionDataid == processDataId)
+ {
+ ArrayList userFlags = new ArrayList();
+ UserFlags flags = (UserFlags)sessionData.UserFlags;
+ foreach (UserFlags flag in Enum.GetValues(typeof(UserFlags)))
+ {
+ if (flags.HasFlag(flag))
+ {
+ string flagName = flag.ToString().Substring(6);
+ flagName = flagName.Replace('_', ' ');
+ flagName = flagName.First().ToString().ToUpper() + flagName.Substring(1).ToLower();
+ userFlags.Add(flagName);
+ }
+ }
+
+ sessionInfo = new SessionInfo()
+ {
+ AuthenticationPackage = Marshal.PtrToStringUni(sessionData.AuthenticationPackage.buffer),
+ DnsDomainName = Marshal.PtrToStringUni(sessionData.DnsDomainName.buffer),
+ LoginDomain = Marshal.PtrToStringUni(sessionData.LoginDomain.buffer),
+ LoginTime = ConvertIntegerToDateString(sessionData.LoginTime),
+ LogonId = ConvertLuidToUint(sessionData.LogonId),
+ LogonServer = Marshal.PtrToStringUni(sessionData.LogonServer.buffer),
+ LogonType = sessionData.LogonType,
+ Upn = Marshal.PtrToStringUni(sessionData.Upn.buffer),
+ UserFlags = userFlags,
+ Account = new Sid(sessionData.Sid)
+ };
+ break;
+ }
+ luidPtr = new IntPtr(luidPtr.ToInt64() + Marshal.SizeOf(typeof(LUID)));
+ }
+ }
+ finally
+ {
+ LsaFreeReturnBuffer(new IntPtr(luidAddr));
+ }
+
+ if (sessionInfo == null)
+ throw new Exception(String.Format("Could not find the data for logon session {0}", processDataId));
+ return sessionInfo;
+ }
+
+ private static string ConvertIntegerToDateString(UInt64 time)
+ {
+ if (time == 0)
+ return null;
+ if (time > (UInt64)DateTime.MaxValue.ToFileTime())
+ return null;
+
+ DateTime dateTime = DateTime.FromFileTime((long)time);
+ return dateTime.ToString("o");
+ }
+
+ private static UInt64 ConvertLuidToUint(LUID luid)
+ {
+ UInt32 low = luid.LowPart;
+ UInt64 high = (UInt64)luid.HighPart;
+ high = high << 32;
+ UInt64 uintValue = (high | (UInt64)low);
+ return uintValue;
+ }
+
+ private static void PtrToStructureArray<T>(T[] array, Int64 pointerAddress)
+ {
+ Int64 pointerOffset = pointerAddress;
+ for (int i = 0; i < array.Length; i++, pointerOffset += Marshal.SizeOf(typeof(T)))
+ array[i] = (T)Marshal.PtrToStructure(new IntPtr(pointerOffset), typeof(T));
+ }
+
+ public static IEnumerable<T> GetValues<T>()
+ {
+ return Enum.GetValues(typeof(T)).Cast<T>();
+ }
+ }
+}
+'@
+
+$original_tmp = $env:TMP
+$env:TMP = $_remote_tmp
+Add-Type -TypeDefinition $session_util
+$env:TMP = $original_tmp
+
+$session_info = [Ansible.SessionUtil]::GetSessionInfo()
+
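+# Recursively flattens the C# session objects into primitive types for JSON
+# output. The unary comma in the return keeps PowerShell from unrolling
+# single-element ArrayLists into scalars.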
+Function Convert-Value($value) {
+ $new_value = $value
+ if ($value -is [System.Collections.ArrayList]) {
+ $new_value = [System.Collections.ArrayList]@()
+ foreach ($list_value in $value) {
+ $new_list_value = Convert-Value -value $list_value
+ [void]$new_value.Add($new_list_value)
+ }
+ } elseif ($value -is [Hashtable]) {
+ $new_value = @{}
+ foreach ($entry in $value.GetEnumerator()) {
+ $entry_value = Convert-Value -value $entry.Value
+ # manually convert Sid type entry to remove the SidType prefix
+ if ($entry.Name -eq "type") {
+ $entry_value = $entry_value.Replace("SidType", "")
+ }
+ $new_value[$entry.Name] = $entry_value
+ }
+ } elseif ($value -is [Ansible.Sid]) {
+ $new_value = @{
+ sid = $value.SidString
+ account_name = $value.AccountName
+ domain_name = $value.DomainName
+ type = $value.SidType.ToString().Replace("SidType", "")
+ }
+ } elseif ($value -is [Enum]) {
+ $new_value = $value.ToString()
+ }
+
+ return ,$new_value
+}
+
+$result = @{
+ changed = $false
+}
+
+$properties = [type][Ansible.SessionInfo]
+foreach ($property in $properties.DeclaredProperties) {
+ $property_name = $property.Name
+ $property_value = $session_info.$property_name
+ $snake_name = Convert-StringToSnakeCase -string $property_name
+
+ $result.$snake_name = Convert-Value -value $property_value
+}
+
+Exit-Json -obj $result
diff --git a/test/support/windows-integration/plugins/modules/win_whoami.py b/test/support/windows-integration/plugins/modules/win_whoami.py
new file mode 100644
index 00000000..d647374b
--- /dev/null
+++ b/test/support/windows-integration/plugins/modules/win_whoami.py
@@ -0,0 +1,203 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# this is a windows documentation stub. actual code lives in the .ps1
+# file of the same name
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+ 'status': ['preview'],
+ 'supported_by': 'community'}
+
+DOCUMENTATION = r'''
+---
+module: win_whoami
+version_added: "2.5"
+short_description: Get information about the current user and process
+description:
+- Designed to return the same information as the C(whoami /all) command.
+- Also includes logon metadata missing from C(whoami), such as the logon
+  rights, id and type.
+notes:
+- If running this module with a non-admin user, the logon rights will be an
+  empty list, as Administrator rights are required to query LSA for the
+  information.
+seealso:
+- module: win_credential
+- module: win_group_membership
+- module: win_user_right
+author:
+- Jordan Borean (@jborean93)
+'''
+
+EXAMPLES = r'''
+- name: Get whoami information
+ win_whoami:
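+
+# Illustrative sketch, not part of the original examples: register the output
+# so individual fields can be used later in the play.
+- name: Get whoami information and register the result
+  win_whoami:
+  register: whoami_info
+
+- name: Display the logon type
+  debug:
+    var: whoami_info.logon_type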
+'''
+
+RETURN = r'''
+authentication_package:
+ description: The name of the authentication package used to authenticate the
+ user in the session.
+ returned: success
+ type: str
+ sample: Negotiate
+user_flags:
+ description: The user flags for the logon session, see UserFlags in
+ U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa380128).
+ returned: success
+ type: str
+ sample: Winlogon
+upn:
+ description: The user principal name of the current user.
+ returned: success
+ type: str
+ sample: Administrator@DOMAIN.COM
+logon_type:
+ description: The logon type that identifies the logon method, see
+ U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa380129.aspx).
+ returned: success
+ type: str
+ sample: Network
+privileges:
+ description: A dictionary of privileges and their state on the logon token.
+ returned: success
+ type: dict
+ sample: {
+ "SeChangeNotifyPrivileges": "enabled-by-default",
+ "SeRemoteShutdownPrivilege": "disabled",
+ "SeDebugPrivilege": "enabled"
+ }
+label:
+  description: The mandatory label set on the logon session.
+ returned: success
+ type: complex
+ contains:
+ domain_name:
+ description: The domain name of the label SID.
+ returned: success
+ type: str
+ sample: Mandatory Label
+ sid:
+ description: The SID in string form.
+ returned: success
+ type: str
+ sample: S-1-16-12288
+ account_name:
+ description: The account name of the label SID.
+ returned: success
+ type: str
+ sample: High Mandatory Level
+ type:
+ description: The type of SID.
+ returned: success
+ type: str
+ sample: Label
+impersonation_level:
+ description: The impersonation level of the token, only valid if
+ C(token_type) is C(TokenImpersonation), see
+ U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa379572.aspx).
+ returned: success
+ type: str
+ sample: SecurityAnonymous
+login_time:
+  description: The logon time in ISO 8601 format.
+ returned: success
+ type: str
+ sample: '2017-11-27T06:24:14.3321665+10:00'
+groups:
+  description: A list of groups the user is a member of, with the attributes of each group.
+ returned: success
+ type: list
+ sample: [
+ {
+ "account_name": "Domain Users",
+ "domain_name": "DOMAIN",
+ "attributes": [
+ "Mandatory",
+ "Enabled by default",
+ "Enabled"
+ ],
+ "sid": "S-1-5-21-1654078763-769949647-2968445802-513",
+ "type": "Group"
+ },
+ {
+ "account_name": "Administrators",
+ "domain_name": "BUILTIN",
+ "attributes": [
+ "Mandatory",
+ "Enabled by default",
+ "Enabled",
+ "Owner"
+ ],
+ "sid": "S-1-5-32-544",
+ "type": "Alias"
+ }
+ ]
+account:
+ description: The running account SID details.
+ returned: success
+ type: complex
+ contains:
+ domain_name:
+ description: The domain name of the account SID.
+ returned: success
+ type: str
+ sample: DOMAIN
+ sid:
+ description: The SID in string form.
+ returned: success
+ type: str
+ sample: S-1-5-21-1654078763-769949647-2968445802-500
+ account_name:
+ description: The account name of the account SID.
+ returned: success
+ type: str
+ sample: Administrator
+ type:
+ description: The type of SID.
+ returned: success
+ type: str
+ sample: User
+login_domain:
+ description: The name of the domain used to authenticate the owner of the
+ session.
+ returned: success
+ type: str
+ sample: DOMAIN
+rights:
+ description: A list of logon rights assigned to the logon.
+ returned: success and running user is a member of the local Administrators group
+ type: list
+ sample: [
+ "SeNetworkLogonRight",
+ "SeInteractiveLogonRight",
+ "SeBatchLogonRight",
+ "SeRemoteInteractiveLogonRight"
+ ]
+logon_server:
+ description: The name of the server used to authenticate the owner of the
+ logon session.
+ returned: success
+ type: str
+ sample: DC01
+logon_id:
+ description: The unique identifier of the logon session.
+ returned: success
+ type: int
+ sample: 20470143
+dns_domain_name:
+  description: The DNS name of the logon session. This is an empty string if
+    it is not set.
+ returned: success
+ type: str
+ sample: DOMAIN.COM
+token_type:
+ description: The token type to indicate whether it is a primary or
+ impersonation token.
+ returned: success
+ type: str
+ sample: TokenPrimary
+'''