diff --git a/include/net/udp.h b/include/net/udp.h
index 9e82cb391dea2b64fec14f8e6323c22d4c7b5c53..a496e441645e14baacd92826315b09509be0a3b1 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -252,6 +252,17 @@ static inline int udp_rqueue_get(struct sock *sk)
 	return sk_rmem_alloc_get(sk) - READ_ONCE(udp_sk(sk)->forward_deficit);
 }
 
+static inline bool udp_sk_bound_dev_eq(struct net *net, int bound_dev_if,
+				       int dif, int sdif)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+	return inet_bound_dev_eq(!!net->ipv4.sysctl_udp_l3mdev_accept,
+				 bound_dev_if, dif, sdif);
+#else
+	return inet_bound_dev_eq(true, bound_dev_if, dif, sdif);
+#endif
+}
+
 /* net/ipv4/udp.c */
 void udp_destruct_sock(struct sock *sk);
 void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len);
diff --git a/net/core/sock.c b/net/core/sock.c
index 080a880a1761b8e0efafaddf0ddac5bb87c64f88..7b304e454a38599a5851ac2d201aa2f43b8265fe 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -567,6 +567,8 @@ static int sock_setbindtodevice(struct sock *sk, char __user *optval,
 
 	lock_sock(sk);
 	sk->sk_bound_dev_if = index;
+	if (sk->sk_prot->rehash)
+		sk->sk_prot->rehash(sk);
 	sk_dst_reset(sk);
 	release_sock(sk);
 
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 1976fddb9e00515072210c6bbcde929cb2832c73..cf73c9194bb6947632c4cb220ca472d821dc89ed 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -371,6 +371,7 @@ static int compute_score(struct sock *sk, struct net *net,
 {
 	int score;
 	struct inet_sock *inet;
+	bool dev_match;
 
 	if (!net_eq(sock_net(sk), net) ||
 	    udp_sk(sk)->udp_port_hash != hnum ||
@@ -398,15 +399,11 @@ static int compute_score(struct sock *sk, struct net *net,
 		score += 4;
 	}
 
-	if (sk->sk_bound_dev_if || exact_dif) {
-		bool dev_match = (sk->sk_bound_dev_if == dif ||
-				  sk->sk_bound_dev_if == sdif);
-
-		if (!dev_match)
-			return -1;
-		if (sk->sk_bound_dev_if)
-			score += 4;
-	}
+	dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if,
+					dif, sdif);
+	if (!dev_match)
+		return -1;
+	score += 4;
 
 	if (sk->sk_incoming_cpu == raw_smp_processor_id())
 		score++;
diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c
index 1ede7a16a0bec897a8e09b79915f16dbcd46cd2d..bde08aa549f38cf844227cad37bb2b5e67ddaf88 100644
--- a/net/ipv6/datagram.c
+++ b/net/ipv6/datagram.c
@@ -772,6 +772,7 @@ int ip6_datagram_send_ctl(struct net *net, struct sock *sk,
 		case IPV6_2292PKTINFO:
 		    {
 			struct net_device *dev = NULL;
+			int src_idx;
 
 			if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct in6_pktinfo))) {
 				err = -EINVAL;
@@ -779,12 +780,15 @@ int ip6_datagram_send_ctl(struct net *net, struct sock *sk,
 			}
 
 			src_info = (struct in6_pktinfo *)CMSG_DATA(cmsg);
+			src_idx = src_info->ipi6_ifindex;
 
-			if (src_info->ipi6_ifindex) {
+			if (src_idx) {
 				if (fl6->flowi6_oif &&
-				    src_info->ipi6_ifindex != fl6->flowi6_oif)
+				    src_idx != fl6->flowi6_oif &&
+				    (sk->sk_bound_dev_if != fl6->flowi6_oif ||
+				     !sk_dev_equal_l3scope(sk, src_idx)))
 					return -EINVAL;
-				fl6->flowi6_oif = src_info->ipi6_ifindex;
+				fl6->flowi6_oif = src_idx;
 			}
 
 			addr_type = __ipv6_addr_type(&src_info->ipi6_addr);
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index d2d97d07ef27a4e75f236e6a530894c0b3609a58..0559adc2f357ef803b105ee9e14a6c2b61309cfa 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -117,6 +117,7 @@ static int compute_score(struct sock *sk, struct net *net,
 {
 	int score;
 	struct inet_sock *inet;
+	bool dev_match;
 
 	if (!net_eq(sock_net(sk), net) ||
 	    udp_sk(sk)->udp_port_hash != hnum ||
@@ -144,15 +145,10 @@ static int compute_score(struct sock *sk, struct net *net,
 		score++;
 	}
 
-	if (sk->sk_bound_dev_if || exact_dif) {
-		bool dev_match = (sk->sk_bound_dev_if == dif ||
-				  sk->sk_bound_dev_if == sdif);
-
-		if (!dev_match)
-			return -1;
-		if (sk->sk_bound_dev_if)
-			score++;
-	}
+	dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif);
+	if (!dev_match)
+		return -1;
+	score++;
 
 	if (sk->sk_incoming_cpu == raw_smp_processor_id())
 		score++;