diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index 05f78cfd1f1d..c21d1e3c4876 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -2623,19 +2623,9 @@ static int __iommu_map_domain_pgtbl(struct iommu_domain *domain, size_t orig_size = size; int ret = 0; - might_sleep_if(gfpflags_allow_blocking(gfp)); - - if (unlikely(!(domain->type & __IOMMU_DOMAIN_PAGING))) - return -EINVAL; - - if (WARN_ON(!ops->map_pages || domain->pgsize_bitmap == 0UL)) + if (WARN_ON(!ops->map_pages)) return -ENODEV; - /* Discourage passing strange GFP flags */ - if (WARN_ON_ONCE(gfp & (__GFP_COMP | __GFP_DMA | __GFP_DMA32 | - __GFP_HIGHMEM))) - return -EINVAL; - /* find out the minimum page size supported */ min_pagesz = 1 << __ffs(domain->pgsize_bitmap); @@ -2697,6 +2687,15 @@ int iommu_map_nosync(struct iommu_domain *domain, unsigned long iova, struct pt_iommu *pt = iommupt_from_domain(domain); int ret; + might_sleep_if(gfpflags_allow_blocking(gfp)); + + /* Discourage passing strange GFP flags or illegal domains */ + if (WARN_ON_ONCE(!(domain->type & __IOMMU_DOMAIN_PAGING) || + !domain->pgsize_bitmap || + (gfp & (__GFP_COMP | __GFP_DMA | __GFP_DMA32 | + __GFP_HIGHMEM)))) + return -EINVAL; + if (pt) { size_t mapped = 0; @@ -2706,11 +2705,12 @@ int iommu_map_nosync(struct iommu_domain *domain, unsigned long iova, iommu_unmap(domain, iova, mapped); return ret; } - return 0; + } else { + ret = __iommu_map_domain_pgtbl(domain, iova, paddr, size, prot, + gfp); + if (ret) + return ret; } - ret = __iommu_map_domain_pgtbl(domain, iova, paddr, size, prot, gfp); - if (ret) - return ret; trace_map(iova, paddr, size); iommu_debug_map(domain, paddr, size); @@ -2742,10 +2742,7 @@ __iommu_unmap_domain_pgtbl(struct iommu_domain *domain, unsigned long iova, size_t unmapped_page, unmapped = 0; unsigned int min_pagesz; - if (unlikely(!(domain->type & __IOMMU_DOMAIN_PAGING))) - return 0; - - if (WARN_ON(!ops->unmap_pages || domain->pgsize_bitmap == 0UL)) + if (WARN_ON(!ops->unmap_pages)) return 0; /* find out the minimum page size supported */ @@ -2764,8 +2761,6 @@ __iommu_unmap_domain_pgtbl(struct iommu_domain *domain, unsigned long iova, pr_debug("unmap this: iova 0x%lx size 0x%zx\n", iova, size); - iommu_debug_unmap_begin(domain, iova, size); - /* * Keep iterating until we either unmap 'size' bytes (or more) * or we hit an area that isn't mapped. @@ -2801,6 +2796,12 @@ static size_t __iommu_unmap(struct iommu_domain *domain, unsigned long iova, struct pt_iommu *pt = iommupt_from_domain(domain); size_t unmapped; + if (WARN_ON_ONCE(!(domain->type & __IOMMU_DOMAIN_PAGING) || + !domain->pgsize_bitmap)) + return 0; + + iommu_debug_unmap_begin(domain, iova, size); + if (pt) unmapped = pt->ops->unmap_range(pt, iova, size, iotlb_gather); else