diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index 2515786d4156..221be84db5ad 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -4074,9 +4074,14 @@ int pci_dev_reset_iommu_prepare(struct pci_dev *pdev) * The pasid_array is mostly fenced by group->mutex, except one reader * in iommu_attach_handle_get(), so it's safe to read without xa_lock. */ - xa_for_each_start(&group->pasid_array, pasid, entry, 1) - iommu_remove_dev_pasid(&pdev->dev, pasid, - pasid_array_entry_to_domain(entry)); + if (pdev->dev.iommu->max_pasids > 0) { + xa_for_each_start(&group->pasid_array, pasid, entry, 1) { + struct iommu_domain *pasid_dom = + pasid_array_entry_to_domain(entry); + + iommu_remove_dev_pasid(&pdev->dev, pasid, pasid_dom); + } + } group->recovery_cnt++; return ret; @@ -4143,10 +4148,16 @@ void pci_dev_reset_iommu_done(struct pci_dev *pdev) * The pasid_array is mostly fenced by group->mutex, except one reader * in iommu_attach_handle_get(), so it's safe to read without xa_lock. */ - xa_for_each_start(&group->pasid_array, pasid, entry, 1) - WARN_ON(__iommu_set_group_pasid( - pasid_array_entry_to_domain(entry), group, pasid, - group->blocking_domain)); + if (pdev->dev.iommu->max_pasids > 0) { + xa_for_each_start(&group->pasid_array, pasid, entry, 1) { + struct iommu_domain *pasid_dom = + pasid_array_entry_to_domain(entry); + + WARN_ON(pasid_dom->ops->set_dev_pasid( + pasid_dom, &pdev->dev, pasid, + group->blocking_domain)); + } + } if (!WARN_ON(group->recovery_cnt == 0)) group->recovery_cnt--;