"Fossies" - the Fresh Open Source Software Archive

Member "AdGuardHome-0.104.3/internal/dnsforward/dnsforward_test.go" (19 Nov 2020, 30696 Bytes) of package /linux/misc/dns/AdGuardHome-0.104.3.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) Go source code syntax highlighting (style: standard) with prefixed line numbers and a code folding option. Alternatively, you can view or download the uninterpreted source code file here. See also the latest Fossies "Diffs" side-by-side code changes report for "dnsforward_test.go": 0.104.1_vs_0.104.3.

    1 package dnsforward
    2 
    3 import (
    4     "crypto/ecdsa"
    5     "crypto/rand"
    6     "crypto/rsa"
    7     "crypto/tls"
    8     "crypto/x509"
    9     "crypto/x509/pkix"
   10     "encoding/pem"
   11     "fmt"
   12     "io/ioutil"
   13     "math/big"
   14     "net"
   15     "os"
   16     "sort"
   17     "sync"
   18     "testing"
   19     "time"
   20 
   21     "github.com/AdguardTeam/AdGuardHome/internal/testutil"
   22     "github.com/AdguardTeam/AdGuardHome/internal/util"
   23 
   24     "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
   25     "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
   26     "github.com/AdguardTeam/dnsproxy/proxy"
   27     "github.com/AdguardTeam/dnsproxy/upstream"
   28     "github.com/miekg/dns"
   29     "github.com/stretchr/testify/assert"
   30 )
   31 
   32 func TestMain(m *testing.M) {
   33     testutil.DiscardLogOutput(m)
   34 }
   35 
   36 const (
   37     tlsServerName     = "testdns.adguard.com"
   38     testMessagesCount = 10
   39 )
   40 
   41 func TestServer(t *testing.T) {
   42     s := createTestServer(t)
   43     err := s.Start()
   44     if err != nil {
   45         t.Fatalf("Failed to start server: %s", err)
   46     }
   47 
   48     // message over UDP
   49     req := createGoogleATestMessage()
   50     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
   51     client := dns.Client{Net: "udp"}
   52     reply, _, err := client.Exchange(req, addr.String())
   53     if err != nil {
   54         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
   55     }
   56     assertGoogleAResponse(t, reply)
   57 
   58     // message over TCP
   59     req = createGoogleATestMessage()
   60     addr = s.dnsProxy.Addr("tcp")
   61     client = dns.Client{Net: "tcp"}
   62     reply, _, err = client.Exchange(req, addr.String())
   63     if err != nil {
   64         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
   65     }
   66     assertGoogleAResponse(t, reply)
   67 
   68     err = s.Stop()
   69     if err != nil {
   70         t.Fatalf("DNS server failed to stop: %s", err)
   71     }
   72 }
   73 
   74 func TestServerWithProtectionDisabled(t *testing.T) {
   75     s := createTestServer(t)
   76     s.conf.ProtectionEnabled = false
   77     err := s.Start()
   78     if err != nil {
   79         t.Fatalf("Failed to start server: %s", err)
   80     }
   81 
   82     // message over UDP
   83     req := createGoogleATestMessage()
   84     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
   85     client := dns.Client{Net: "udp"}
   86     reply, _, err := client.Exchange(req, addr.String())
   87     if err != nil {
   88         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
   89     }
   90     assertGoogleAResponse(t, reply)
   91 
   92     err = s.Stop()
   93     if err != nil {
   94         t.Fatalf("DNS server failed to stop: %s", err)
   95     }
   96 }
   97 
   98 func TestDotServer(t *testing.T) {
   99     // Prepare the proxy server
  100     _, certPem, keyPem := createServerTLSConfig(t)
  101     s := createTestServer(t)
  102 
  103     s.conf.TLSConfig = TLSConfig{
  104         TLSListenAddr:        &net.TCPAddr{Port: 0},
  105         CertificateChainData: certPem,
  106         PrivateKeyData:       keyPem,
  107     }
  108 
  109     _ = s.Prepare(nil)
  110     // Starting the server
  111     err := s.Start()
  112     if err != nil {
  113         t.Fatalf("Failed to start server: %s", err)
  114     }
  115 
  116     // Add our self-signed generated config to roots
  117     roots := x509.NewCertPool()
  118     roots.AppendCertsFromPEM(certPem)
  119     tlsConfig := &tls.Config{
  120         ServerName: tlsServerName,
  121         RootCAs:    roots,
  122         MinVersion: tls.VersionTLS12,
  123     }
  124 
  125     // Create a DNS-over-TLS client connection
  126     addr := s.dnsProxy.Addr(proxy.ProtoTLS)
  127     conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
  128     if err != nil {
  129         t.Fatalf("cannot connect to the proxy: %s", err)
  130     }
  131 
  132     sendTestMessages(t, conn)
  133 
  134     // Stop the proxy
  135     err = s.Stop()
  136     if err != nil {
  137         t.Fatalf("DNS server failed to stop: %s", err)
  138     }
  139 }
  140 
  141 func TestDoqServer(t *testing.T) {
  142     // Prepare the proxy server
  143     _, certPem, keyPem := createServerTLSConfig(t)
  144     s := createTestServer(t)
  145 
  146     s.conf.TLSConfig = TLSConfig{
  147         QUICListenAddr:       &net.UDPAddr{Port: 0},
  148         CertificateChainData: certPem,
  149         PrivateKeyData:       keyPem,
  150     }
  151 
  152     _ = s.Prepare(nil)
  153     // Starting the server
  154     err := s.Start()
  155     assert.Nil(t, err)
  156 
  157     // Create a DNS-over-QUIC upstream
  158     addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
  159     opts := upstream.Options{InsecureSkipVerify: true}
  160     u, err := upstream.AddressToUpstream(fmt.Sprintf("quic://%s", addr), opts)
  161     assert.Nil(t, err)
  162 
  163     // Send the test message
  164     req := createGoogleATestMessage()
  165     res, err := u.Exchange(req)
  166     assert.Nil(t, err)
  167     assertGoogleAResponse(t, res)
  168 
  169     // Stop the proxy
  170     err = s.Stop()
  171     if err != nil {
  172         t.Fatalf("DNS server failed to stop: %s", err)
  173     }
  174 }
  175 
  176 func TestServerRace(t *testing.T) {
  177     s := createTestServer(t)
  178     err := s.Start()
  179     if err != nil {
  180         t.Fatalf("Failed to start server: %s", err)
  181     }
  182 
  183     // message over UDP
  184     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  185     conn, err := dns.Dial("udp", addr.String())
  186     if err != nil {
  187         t.Fatalf("cannot connect to the proxy: %s", err)
  188     }
  189 
  190     sendTestMessagesAsync(t, conn)
  191 
  192     // Stop the proxy
  193     err = s.Stop()
  194     if err != nil {
  195         t.Fatalf("DNS server failed to stop: %s", err)
  196     }
  197 }
  198 
  199 func TestSafeSearch(t *testing.T) {
  200     s := createTestServer(t)
  201     err := s.Start()
  202     if err != nil {
  203         t.Fatalf("Failed to start server: %s", err)
  204     }
  205 
  206     // Test safe search for yandex. We already know safe search ip
  207     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  208     client := dns.Client{Net: "udp"}
  209     yandexDomains := []string{"yandex.com.", "yandex.by.", "yandex.kz.", "yandex.ru.", "yandex.com."}
  210     for _, host := range yandexDomains {
  211         exchangeAndAssertResponse(t, &client, addr, host, "213.180.193.56")
  212     }
  213 
  214     // Let's lookup for google safesearch ip
  215     ips, err := net.LookupIP("forcesafesearch.google.com")
  216     if err != nil {
  217         t.Fatalf("Failed to lookup for forcesafesearch.google.com: %s", err)
  218     }
  219 
  220     ip := ips[0]
  221     for _, i := range ips {
  222         if i.To4() != nil {
  223             ip = i
  224             break
  225         }
  226     }
  227 
  228     // Test safe search for google.
  229     googleDomains := []string{"www.google.com.", "www.google.com.af.", "www.google.be.", "www.google.by."}
  230     for _, host := range googleDomains {
  231         exchangeAndAssertResponse(t, &client, addr, host, ip.String())
  232     }
  233 
  234     err = s.Stop()
  235     if err != nil {
  236         t.Fatalf("Can not stopd server cause: %s", err)
  237     }
  238 }
  239 
  240 func TestInvalidRequest(t *testing.T) {
  241     s := createTestServer(t)
  242     err := s.Start()
  243     if err != nil {
  244         t.Fatalf("Failed to start server: %s", err)
  245     }
  246 
  247     // server is running, send a message
  248     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  249     req := dns.Msg{}
  250     req.Id = dns.Id()
  251     req.RecursionDesired = true
  252 
  253     // send a DNS request without question
  254     client := dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}
  255     _, _, err = client.Exchange(&req, addr.String())
  256     if err != nil {
  257         t.Fatalf("got a response to an invalid query")
  258     }
  259 
  260     err = s.Stop()
  261     if err != nil {
  262         t.Fatalf("DNS server failed to stop: %s", err)
  263     }
  264 }
  265 
  266 func TestBlockedRequest(t *testing.T) {
  267     s := createTestServer(t)
  268     err := s.Start()
  269     if err != nil {
  270         t.Fatalf("Failed to start server: %s", err)
  271     }
  272     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  273 
  274     //
  275     // Default blocking - NULL IP
  276     //
  277     req := dns.Msg{}
  278     req.Id = dns.Id()
  279     req.RecursionDesired = true
  280     req.Question = []dns.Question{
  281         {Name: "nxdomain.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
  282     }
  283 
  284     reply, err := dns.Exchange(&req, addr.String())
  285     if err != nil {
  286         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
  287     }
  288     assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
  289     assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0")))
  290 
  291     err = s.Stop()
  292     if err != nil {
  293         t.Fatalf("DNS server failed to stop: %s", err)
  294     }
  295 }
  296 
  297 func TestServerCustomClientUpstream(t *testing.T) {
  298     s := createTestServer(t)
  299     s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig {
  300         uc := &proxy.UpstreamConfig{}
  301         u := &testUpstream{}
  302         u.ipv4 = map[string][]net.IP{}
  303         u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")}
  304         uc.Upstreams = append(uc.Upstreams, u)
  305         return uc
  306     }
  307 
  308     assert.Nil(t, s.Start())
  309 
  310     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  311 
  312     // Send test request
  313     req := dns.Msg{}
  314     req.Id = dns.Id()
  315     req.RecursionDesired = true
  316     req.Question = []dns.Question{
  317         {Name: "host.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
  318     }
  319 
  320     reply, err := dns.Exchange(&req, addr.String())
  321 
  322     assert.Nil(t, err)
  323     assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
  324     assert.NotNil(t, reply.Answer)
  325     assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String())
  326     assert.Nil(t, s.Stop())
  327 }
  328 
// testUpstream is a mock of a real upstream.  Fill its maps with the records
// it should serve; Exchange answers queries from these maps only.
//
// NOTE: the field order matters, because tests build it with positional
// composite literals such as &testUpstream{testCNAMEs, testIPv4, nil}.
type testUpstream struct {
	cn   map[string]string   // Map of [name]canonical_name
	ipv4 map[string][]net.IP // Map of [name]IPv4
	ipv6 map[string][]net.IP // Map of [name]IPv6
}
  336 
  337 func (u *testUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) {
  338     resp := dns.Msg{}
  339     resp.SetReply(m)
  340     hasARecord := false
  341     hasAAAARecord := false
  342 
  343     reqType := m.Question[0].Qtype
  344     name := m.Question[0].Name
  345 
  346     // Let's check if we have any CNAME for given name
  347     if cname, ok := u.cn[name]; ok {
  348         cn := dns.CNAME{}
  349         cn.Hdr.Name = name
  350         cn.Hdr.Rrtype = dns.TypeCNAME
  351         cn.Target = cname
  352         resp.Answer = append(resp.Answer, &cn)
  353     }
  354 
  355     // Let's check if we can add some A records to the answer
  356     if ipv4addr, ok := u.ipv4[name]; ok && reqType == dns.TypeA {
  357         hasARecord = true
  358         for _, ipv4 := range ipv4addr {
  359             respA := dns.A{}
  360             respA.Hdr.Rrtype = dns.TypeA
  361             respA.Hdr.Name = name
  362             respA.A = ipv4
  363             resp.Answer = append(resp.Answer, &respA)
  364         }
  365     }
  366 
  367     // Let's check if we can add some AAAA records to the answer
  368     if u.ipv6 != nil {
  369         if ipv6addr, ok := u.ipv6[name]; ok && reqType == dns.TypeAAAA {
  370             hasAAAARecord = true
  371             for _, ipv6 := range ipv6addr {
  372                 respAAAA := dns.A{}
  373                 respAAAA.Hdr.Rrtype = dns.TypeAAAA
  374                 respAAAA.Hdr.Name = name
  375                 respAAAA.A = ipv6
  376                 resp.Answer = append(resp.Answer, &respAAAA)
  377             }
  378         }
  379     }
  380 
  381     if len(resp.Answer) == 0 {
  382         if hasARecord || hasAAAARecord {
  383             // Set No Error RCode if there are some records for given Qname but we didn't apply them
  384             resp.SetRcode(m, dns.RcodeSuccess)
  385         } else {
  386             // Set NXDomain RCode otherwise
  387             resp.SetRcode(m, dns.RcodeNameError)
  388         }
  389     }
  390 
  391     return &resp, nil
  392 }
  393 
  394 func (u *testUpstream) Address() string {
  395     return "test"
  396 }
  397 
// startWithUpstream prepares s with its current configuration, points the
// resulting proxy at the single upstream u, and starts that proxy.  The
// server's lock is held for the whole sequence.
//
// NOTE(review): this starts s.dnsProxy directly rather than calling s.Start,
// so any extra work s.Start does is skipped — confirm that is intended.
func (s *Server) startWithUpstream(u upstream.Upstream) error {
	s.Lock()
	defer s.Unlock()
	err := s.Prepare(nil)
	if err != nil {
		return err
	}
	// Replace the proxy's upstream configuration so that queries only go
	// to u; this must happen after Prepare and before the proxy starts.
	s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
		Upstreams: []upstream.Upstream{u},
	}
	return s.dnsProxy.Start()
}
  410 
  411 // testCNAMEs is a simple map of names and CNAMEs necessary for the testUpstream work
  412 var testCNAMEs = map[string]string{
  413     "badhost.":               "null.example.org.",
  414     "whitelist.example.org.": "null.example.org.",
  415 }
  416 
  417 // testIPv4 is a simple map of names and IPv4s necessary for the testUpstream work
  418 var testIPv4 = map[string][]net.IP{
  419     "null.example.org.": {{1, 2, 3, 4}},
  420     "example.org.":      {{127, 0, 0, 255}},
  421 }
  422 
  423 func TestBlockCNAMEProtectionEnabled(t *testing.T) {
  424     s := createTestServer(t)
  425     testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
  426     s.conf.ProtectionEnabled = false
  427     err := s.startWithUpstream(testUpstm)
  428     assert.True(t, err == nil)
  429     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  430 
  431     // 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
  432     // but protection is disabled - response is NOT blocked
  433     req := createTestMessage("badhost.")
  434     reply, err := dns.Exchange(req, addr.String())
  435     assert.Nil(t, err)
  436     assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
  437 }
  438 
  439 func TestBlockCNAME(t *testing.T) {
  440     s := createTestServer(t)
  441     testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
  442     err := s.startWithUpstream(testUpstm)
  443     assert.True(t, err == nil)
  444     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  445 
  446     // 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
  447     // response is blocked
  448     req := createTestMessage("badhost.")
  449     reply, err := dns.Exchange(req, addr.String())
  450     assert.Nil(t, err, nil)
  451     assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
  452     assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0")))
  453 
  454     // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
  455     //   but 'whitelist.example.org' is in a whitelist:
  456     // response isn't blocked
  457     req = createTestMessage("whitelist.example.org.")
  458     reply, err = dns.Exchange(req, addr.String())
  459     assert.Nil(t, err)
  460     assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
  461 
  462     // 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters:
  463     // response is blocked
  464     req = createTestMessage("example.org.")
  465     reply, err = dns.Exchange(req, addr.String())
  466     assert.Nil(t, err)
  467     assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
  468     assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0")))
  469 
  470     _ = s.Stop()
  471 }
  472 
  473 func TestClientRulesForCNAMEMatching(t *testing.T) {
  474     s := createTestServer(t)
  475     testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
  476     s.conf.FilterHandler = func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) {
  477         settings.FilteringEnabled = false
  478     }
  479     err := s.startWithUpstream(testUpstm)
  480     assert.Nil(t, err)
  481     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  482 
  483     // 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
  484     // response is blocked
  485     req := dns.Msg{}
  486     req.Id = dns.Id()
  487     req.Question = []dns.Question{
  488         {Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
  489     }
  490     // However, in our case it should not be blocked
  491     // as filtering is disabled on the client level
  492     reply, err := dns.Exchange(&req, addr.String())
  493     assert.Nil(t, err)
  494     assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
  495 }
  496 
  497 func TestNullBlockedRequest(t *testing.T) {
  498     s := createTestServer(t)
  499     s.conf.FilteringConfig.BlockingMode = "null_ip"
  500     err := s.Start()
  501     if err != nil {
  502         t.Fatalf("Failed to start server: %s", err)
  503     }
  504     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  505 
  506     //
  507     // Null filter blocking
  508     //
  509     req := dns.Msg{}
  510     req.Id = dns.Id()
  511     req.RecursionDesired = true
  512     req.Question = []dns.Question{
  513         {Name: "null.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
  514     }
  515 
  516     reply, err := dns.Exchange(&req, addr.String())
  517     if err != nil {
  518         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
  519     }
  520     if len(reply.Answer) != 1 {
  521         t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
  522     }
  523     if a, ok := reply.Answer[0].(*dns.A); ok {
  524         if !net.IPv4zero.Equal(a.A) {
  525             t.Fatalf("DNS server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A)
  526         }
  527     } else {
  528         t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
  529     }
  530 
  531     err = s.Stop()
  532     if err != nil {
  533         t.Fatalf("DNS server failed to stop: %s", err)
  534     }
  535 }
  536 
  537 func TestBlockedCustomIP(t *testing.T) {
  538     rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1   host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n"
  539     filters := []dnsfilter.Filter{{
  540         ID: 0, Data: []byte(rules),
  541     }}
  542     c := dnsfilter.Config{}
  543 
  544     f := dnsfilter.New(&c, filters)
  545     s := NewServer(DNSCreateParams{DNSFilter: f})
  546     conf := ServerConfig{}
  547     conf.UDPListenAddr = &net.UDPAddr{Port: 0}
  548     conf.TCPListenAddr = &net.TCPAddr{Port: 0}
  549     conf.ProtectionEnabled = true
  550     conf.BlockingMode = "custom_ip"
  551     conf.BlockingIPv4 = "bad IP"
  552     conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
  553     err := s.Prepare(&conf)
  554     assert.True(t, err != nil) // invalid BlockingIPv4
  555 
  556     conf.BlockingIPv4 = "0.0.0.1"
  557     conf.BlockingIPv6 = "::1"
  558     err = s.Prepare(&conf)
  559     assert.Nil(t, err)
  560     err = s.Start()
  561     assert.Nil(t, err)
  562 
  563     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  564 
  565     req := createTestMessageWithType("null.example.org.", dns.TypeA)
  566     reply, err := dns.Exchange(req, addr.String())
  567     assert.Nil(t, err)
  568     assert.Equal(t, 1, len(reply.Answer))
  569     a, ok := reply.Answer[0].(*dns.A)
  570     assert.True(t, ok)
  571     assert.Equal(t, "0.0.0.1", a.A.String())
  572 
  573     req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
  574     reply, err = dns.Exchange(req, addr.String())
  575     assert.Nil(t, err)
  576     assert.Equal(t, 1, len(reply.Answer))
  577     a6, ok := reply.Answer[0].(*dns.AAAA)
  578     assert.True(t, ok)
  579     assert.Equal(t, "::1", a6.AAAA.String())
  580 
  581     err = s.Stop()
  582     if err != nil {
  583         t.Fatalf("DNS server failed to stop: %s", err)
  584     }
  585 }
  586 
  587 func TestBlockedByHosts(t *testing.T) {
  588     s := createTestServer(t)
  589     err := s.Start()
  590     if err != nil {
  591         t.Fatalf("Failed to start server: %s", err)
  592     }
  593     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  594 
  595     //
  596     // Hosts blocking
  597     //
  598     req := dns.Msg{}
  599     req.Id = dns.Id()
  600     req.RecursionDesired = true
  601     req.Question = []dns.Question{
  602         {Name: "host.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
  603     }
  604 
  605     reply, err := dns.Exchange(&req, addr.String())
  606     if err != nil {
  607         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
  608     }
  609     if len(reply.Answer) != 1 {
  610         t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
  611     }
  612     if a, ok := reply.Answer[0].(*dns.A); ok {
  613         if !net.IPv4(127, 0, 0, 1).Equal(a.A) {
  614             t.Fatalf("DNS server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
  615         }
  616     } else {
  617         t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
  618     }
  619 
  620     err = s.Stop()
  621     if err != nil {
  622         t.Fatalf("DNS server failed to stop: %s", err)
  623     }
  624 }
  625 
  626 func TestBlockedBySafeBrowsing(t *testing.T) {
  627     s := createTestServer(t)
  628     err := s.Start()
  629     if err != nil {
  630         t.Fatalf("Failed to start server: %s", err)
  631     }
  632     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  633 
  634     //
  635     // Safebrowsing blocking
  636     //
  637     req := dns.Msg{}
  638     req.Id = dns.Id()
  639     req.RecursionDesired = true
  640     req.Question = []dns.Question{
  641         {Name: "wmconvirus.narod.ru.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
  642     }
  643     reply, err := dns.Exchange(&req, addr.String())
  644     if err != nil {
  645         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
  646     }
  647     if len(reply.Answer) != 1 {
  648         t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
  649     }
  650     if a, ok := reply.Answer[0].(*dns.A); ok {
  651         addrs, lookupErr := net.LookupHost(safeBrowsingBlockHost)
  652         if lookupErr != nil {
  653             t.Fatalf("cannot resolve %s due to %s", safeBrowsingBlockHost, lookupErr)
  654         }
  655 
  656         found := false
  657         for _, blockAddr := range addrs {
  658             if blockAddr == a.A.String() {
  659                 found = true
  660             }
  661         }
  662 
  663         if !found {
  664             t.Fatalf("DNS server %s returned wrong answer: %v", addr, a.A)
  665         }
  666     } else {
  667         t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
  668     }
  669 
  670     err = s.Stop()
  671     if err != nil {
  672         t.Fatalf("DNS server failed to stop: %s", err)
  673     }
  674 }
  675 
  676 func TestRewrite(t *testing.T) {
  677     c := dnsfilter.Config{}
  678     c.Rewrites = []dnsfilter.RewriteEntry{
  679         {
  680             Domain: "test.com",
  681             Answer: "1.2.3.4",
  682             Type:   dns.TypeA,
  683         },
  684         {
  685             Domain: "alias.test.com",
  686             Answer: "test.com",
  687             Type:   dns.TypeCNAME,
  688         },
  689         {
  690             Domain: "my.alias.example.org",
  691             Answer: "example.org",
  692             Type:   dns.TypeCNAME,
  693         },
  694     }
  695 
  696     f := dnsfilter.New(&c, nil)
  697     s := NewServer(DNSCreateParams{DNSFilter: f})
  698     conf := ServerConfig{}
  699     conf.UDPListenAddr = &net.UDPAddr{Port: 0}
  700     conf.TCPListenAddr = &net.TCPAddr{Port: 0}
  701     conf.ProtectionEnabled = true
  702     conf.UpstreamDNS = []string{"8.8.8.8:53"}
  703 
  704     err := s.Prepare(&conf)
  705     assert.Nil(t, err)
  706     err = s.Start()
  707     assert.Nil(t, err)
  708     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
  709 
  710     req := createTestMessageWithType("test.com.", dns.TypeA)
  711     reply, err := dns.Exchange(req, addr.String())
  712     assert.Nil(t, err)
  713     assert.Equal(t, 1, len(reply.Answer))
  714     a, ok := reply.Answer[0].(*dns.A)
  715     assert.True(t, ok)
  716     assert.Equal(t, "1.2.3.4", a.A.String())
  717 
  718     req = createTestMessageWithType("test.com.", dns.TypeAAAA)
  719     reply, err = dns.Exchange(req, addr.String())
  720     assert.Nil(t, err)
  721     assert.Equal(t, 0, len(reply.Answer))
  722 
  723     req = createTestMessageWithType("alias.test.com.", dns.TypeA)
  724     reply, err = dns.Exchange(req, addr.String())
  725     assert.Nil(t, err)
  726     assert.Equal(t, 2, len(reply.Answer))
  727     assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
  728     assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String())
  729 
  730     req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
  731     reply, err = dns.Exchange(req, addr.String())
  732     assert.Nil(t, err)
  733     assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored
  734     assert.Equal(t, 2, len(reply.Answer))
  735     assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
  736     assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
  737 
  738     _ = s.Stop()
  739 }
  740 
  741 func createTestServer(t *testing.T) *Server {
  742     rules := `||nxdomain.example.org
  743 ||null.example.org^
  744 127.0.0.1   host.example.org
  745 @@||whitelist.example.org^
  746 ||127.0.0.255`
  747     filters := []dnsfilter.Filter{{
  748         ID: 0, Data: []byte(rules),
  749     }}
  750     c := dnsfilter.Config{}
  751     c.SafeBrowsingEnabled = true
  752     c.SafeBrowsingCacheSize = 1000
  753     c.SafeSearchEnabled = true
  754     c.SafeSearchCacheSize = 1000
  755     c.ParentalCacheSize = 1000
  756     c.CacheTime = 30
  757 
  758     f := dnsfilter.New(&c, filters)
  759 
  760     s := NewServer(DNSCreateParams{DNSFilter: f})
  761     s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
  762     s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
  763     s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
  764     s.conf.FilteringConfig.ProtectionEnabled = true
  765     s.conf.ConfigModified = func() {}
  766 
  767     err := s.Prepare(nil)
  768     assert.True(t, err == nil)
  769     return s
  770 }
  771 
  772 func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
  773     privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  774     if err != nil {
  775         t.Fatalf("cannot generate RSA key: %s", err)
  776     }
  777 
  778     serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
  779     serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
  780     if err != nil {
  781         t.Fatalf("failed to generate serial number: %s", err)
  782     }
  783 
  784     notBefore := time.Now()
  785     notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
  786 
  787     template := x509.Certificate{
  788         SerialNumber: serialNumber,
  789         Subject: pkix.Name{
  790             Organization: []string{"AdGuard Tests"},
  791         },
  792         NotBefore: notBefore,
  793         NotAfter:  notAfter,
  794 
  795         KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
  796         ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
  797         BasicConstraintsValid: true,
  798         IsCA:                  true,
  799     }
  800     template.DNSNames = append(template.DNSNames, tlsServerName)
  801 
  802     derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
  803     if err != nil {
  804         t.Fatalf("failed to create certificate: %s", err)
  805     }
  806 
  807     certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
  808     keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
  809 
  810     cert, err := tls.X509KeyPair(certPem, keyPem)
  811     if err != nil {
  812         t.Fatalf("failed to create certificate: %s", err)
  813     }
  814 
  815     return &tls.Config{Certificates: []tls.Certificate{cert}, ServerName: tlsServerName, MinVersion: tls.VersionTLS12}, certPem, keyPem
  816 }
  817 
  818 func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) {
  819     defer func() {
  820         g.Done()
  821     }()
  822 
  823     req := createGoogleATestMessage()
  824     err := conn.WriteMsg(req)
  825     if err != nil {
  826         panic(fmt.Sprintf("cannot write message: %s", err))
  827     }
  828 
  829     res, err := conn.ReadMsg()
  830     if err != nil {
  831         panic(fmt.Sprintf("cannot read response to message: %s", err))
  832     }
  833     assertGoogleAResponse(t, res)
  834 }
  835 
  836 // sendTestMessagesAsync sends messages in parallel
  837 // so that we could find race issues
  838 func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
  839     g := &sync.WaitGroup{}
  840     g.Add(testMessagesCount)
  841 
  842     for i := 0; i < testMessagesCount; i++ {
  843         go sendTestMessageAsync(t, conn, g)
  844     }
  845 
  846     g.Wait()
  847 }
  848 
  849 func sendTestMessages(t *testing.T, conn *dns.Conn) {
  850     for i := 0; i < 10; i++ {
  851         req := createGoogleATestMessage()
  852         err := conn.WriteMsg(req)
  853         if err != nil {
  854             t.Fatalf("cannot write message #%d: %s", i, err)
  855         }
  856 
  857         res, err := conn.ReadMsg()
  858         if err != nil {
  859             t.Fatalf("cannot read response to message #%d: %s", i, err)
  860         }
  861         assertGoogleAResponse(t, res)
  862     }
  863 }
  864 
  865 func exchangeAndAssertResponse(t *testing.T, client *dns.Client, addr net.Addr, host, ip string) {
  866     req := createTestMessage(host)
  867     reply, _, err := client.Exchange(req, addr.String())
  868     if err != nil {
  869         t.Fatalf("Couldn't talk to server %s: %s", addr, err)
  870     }
  871     assertResponse(t, reply, ip)
  872 }
  873 
  874 func createGoogleATestMessage() *dns.Msg {
  875     return createTestMessage("google-public-dns-a.google.com.")
  876 }
  877 
  878 func createTestMessage(host string) *dns.Msg {
  879     req := dns.Msg{}
  880     req.Id = dns.Id()
  881     req.RecursionDesired = true
  882     req.Question = []dns.Question{
  883         {Name: host, Qtype: dns.TypeA, Qclass: dns.ClassINET},
  884     }
  885     return &req
  886 }
  887 
  888 func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
  889     req := dns.Msg{}
  890     req.Id = dns.Id()
  891     req.RecursionDesired = true
  892     req.Question = []dns.Question{
  893         {Name: host, Qtype: qtype, Qclass: dns.ClassINET},
  894     }
  895     return &req
  896 }
  897 
  898 func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
  899     assertResponse(t, reply, "8.8.8.8")
  900 }
  901 
  902 func assertResponse(t *testing.T, reply *dns.Msg, ip string) {
  903     if len(reply.Answer) != 1 {
  904         t.Fatalf("DNS server returned reply with wrong number of answers - %d", len(reply.Answer))
  905     }
  906     if a, ok := reply.Answer[0].(*dns.A); ok {
  907         if !net.ParseIP(ip).Equal(a.A) {
  908             t.Fatalf("DNS server returned wrong answer instead of %s: %v", ip, a.A)
  909         }
  910     } else {
  911         t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0])
  912     }
  913 }
  914 
  915 func publicKey(priv interface{}) interface{} {
  916     switch k := priv.(type) {
  917     case *rsa.PrivateKey:
  918         return &k.PublicKey
  919     case *ecdsa.PrivateKey:
  920         return &k.PublicKey
  921     default:
  922         return nil
  923     }
  924 }
  925 
  926 func TestValidateUpstream(t *testing.T) {
  927     invalidUpstreams := []string{
  928         "1.2.3.4.5",
  929         "123.3.7m",
  930         "htttps://google.com/dns-query",
  931         "[/host.com]tls://dns.adguard.com",
  932         "[host.ru]#",
  933     }
  934 
  935     validDefaultUpstreams := []string{
  936         "1.1.1.1",
  937         "tls://1.1.1.1",
  938         "https://dns.adguard.com/dns-query",
  939         "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
  940     }
  941 
  942     validUpstreams := []string{
  943         "[/host.com/]1.1.1.1",
  944         "[//]tls://1.1.1.1",
  945         "[/www.host.com/]#",
  946         "[/host.com/google.com/]8.8.8.8",
  947         "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
  948     }
  949     for _, u := range invalidUpstreams {
  950         _, err := validateUpstream(u)
  951         if err == nil {
  952             t.Fatalf("upstream %s is invalid but it pass through validation", u)
  953         }
  954     }
  955 
  956     for _, u := range validDefaultUpstreams {
  957         defaultUpstream, err := validateUpstream(u)
  958         if err != nil {
  959             t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
  960         }
  961         if !defaultUpstream {
  962             t.Fatalf("upstream %s is default one!", u)
  963         }
  964     }
  965 
  966     for _, u := range validUpstreams {
  967         defaultUpstream, err := validateUpstream(u)
  968         if err != nil {
  969             t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
  970         }
  971         if defaultUpstream {
  972             t.Fatalf("upstream %s is default one!", u)
  973         }
  974     }
  975 }
  976 
  977 func TestValidateUpstreamsSet(t *testing.T) {
  978     // Empty upstreams array
  979     var upstreamsSet []string
  980     err := ValidateUpstreams(upstreamsSet)
  981     assert.Nil(t, err, "empty upstreams array should be valid")
  982 
  983     // Comment in upstreams array
  984     upstreamsSet = []string{"# comment"}
  985     err = ValidateUpstreams(upstreamsSet)
  986     assert.Nil(t, err, "comments should not be validated")
  987 
  988     // Set of valid upstreams. There is no default upstream specified
  989     upstreamsSet = []string{
  990         "[/host.com/]1.1.1.1",
  991         "[//]tls://1.1.1.1",
  992         "[/www.host.com/]#",
  993         "[/host.com/google.com/]8.8.8.8",
  994         "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
  995     }
  996     err = ValidateUpstreams(upstreamsSet)
  997     assert.NotNil(t, err, "there is no default upstream")
  998 
  999     // Let's add default upstream
 1000     upstreamsSet = append(upstreamsSet, "8.8.8.8")
 1001     err = ValidateUpstreams(upstreamsSet)
 1002     assert.Nilf(t, err, "upstreams set is valid, but doesn't pass through validation cause: %s", err)
 1003 
 1004     // Let's add invalid upstream
 1005     upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
 1006     err = ValidateUpstreams(upstreamsSet)
 1007     assert.NotNil(t, err, "there is an invalid upstream in set, but it pass through validation")
 1008 }
 1009 
 1010 func TestIpFromAddr(t *testing.T) {
 1011     addr := net.UDPAddr{}
 1012     addr.IP = net.ParseIP("1:2:3::4")
 1013     addr.Port = 12345
 1014     addr.Zone = "eth0"
 1015     a := ipFromAddr(&addr)
 1016     assert.True(t, a == "1:2:3::4")
 1017 
 1018     a = ipFromAddr(nil)
 1019     assert.True(t, a == "")
 1020 }
 1021 
 1022 func TestMatchDNSName(t *testing.T) {
 1023     dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
 1024     sort.Strings(dnsNames)
 1025     assert.True(t, matchDNSName(dnsNames, "host1"))
 1026     assert.True(t, matchDNSName(dnsNames, "a.host2"))
 1027     assert.True(t, matchDNSName(dnsNames, "b.a.host2"))
 1028     assert.True(t, matchDNSName(dnsNames, "1.2.3.4"))
 1029     assert.True(t, !matchDNSName(dnsNames, "host2"))
 1030     assert.True(t, !matchDNSName(dnsNames, ""))
 1031     assert.True(t, !matchDNSName(dnsNames, "*.host2"))
 1032 }
 1033 
 1034 type testDHCP struct {
 1035 }
 1036 
 1037 func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
 1038     l := dhcpd.Lease{}
 1039     l.IP = net.ParseIP("127.0.0.1").To4()
 1040     l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
 1041     l.Hostname = "localhost"
 1042     return []dhcpd.Lease{l}
 1043 }
 1044 func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
 1045 
 1046 func TestPTRResponseFromDHCPLeases(t *testing.T) {
 1047     dhcp := &testDHCP{}
 1048 
 1049     c := dnsfilter.Config{}
 1050     f := dnsfilter.New(&c, nil)
 1051     s := NewServer(DNSCreateParams{DNSFilter: f, DHCPServer: dhcp})
 1052     s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
 1053     s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
 1054     s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
 1055     s.conf.FilteringConfig.ProtectionEnabled = true
 1056     err := s.Prepare(nil)
 1057     assert.True(t, err == nil)
 1058     assert.Nil(t, s.Start())
 1059 
 1060     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
 1061     req := createTestMessage("1.0.0.127.in-addr.arpa.")
 1062     req.Question[0].Qtype = dns.TypePTR
 1063 
 1064     resp, err := dns.Exchange(req, addr.String())
 1065     assert.Nil(t, err)
 1066     assert.Equal(t, 1, len(resp.Answer))
 1067     assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
 1068     assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
 1069     ptr := resp.Answer[0].(*dns.PTR)
 1070     assert.Equal(t, "localhost.", ptr.Ptr)
 1071 
 1072     s.Close()
 1073 }
 1074 
 1075 func TestPTRResponseFromHosts(t *testing.T) {
 1076     c := dnsfilter.Config{
 1077         AutoHosts: &util.AutoHosts{},
 1078     }
 1079 
 1080     // Prepare test hosts file
 1081     hf, _ := ioutil.TempFile("", "")
 1082     defer func() { _ = os.Remove(hf.Name()) }()
 1083     defer hf.Close()
 1084 
 1085     _, _ = hf.WriteString("  127.0.0.1   host # comment \n")
 1086     _, _ = hf.WriteString("  ::1   localhost#comment  \n")
 1087 
 1088     // Init auto hosts
 1089     c.AutoHosts.Init(hf.Name())
 1090     defer c.AutoHosts.Close()
 1091 
 1092     f := dnsfilter.New(&c, nil)
 1093     s := NewServer(DNSCreateParams{DNSFilter: f})
 1094     s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
 1095     s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
 1096     s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
 1097     s.conf.FilteringConfig.ProtectionEnabled = true
 1098     err := s.Prepare(nil)
 1099     assert.True(t, err == nil)
 1100     assert.Nil(t, s.Start())
 1101 
 1102     addr := s.dnsProxy.Addr(proxy.ProtoUDP)
 1103     req := createTestMessage("1.0.0.127.in-addr.arpa.")
 1104     req.Question[0].Qtype = dns.TypePTR
 1105 
 1106     resp, err := dns.Exchange(req, addr.String())
 1107     assert.Nil(t, err)
 1108     assert.Equal(t, 1, len(resp.Answer))
 1109     assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
 1110     assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
 1111     ptr := resp.Answer[0].(*dns.PTR)
 1112     assert.Equal(t, "host.", ptr.Ptr)
 1113 
 1114     s.Close()
 1115 }