ao-usbload: Check target device name to avoid mis-flashing
authorKeith Packard <keithp@keithp.com>
Sun, 14 Aug 2016 22:57:39 +0000 (15:57 -0700)
committerKeith Packard <keithp@keithp.com>
Sun, 14 Aug 2016 22:57:39 +0000 (15:57 -0700)
Instead of blindly loading firmware, go get the old device's name and
make sure it matches the new firmware.

Add --force option to allow this to be overridden.

Signed-off-by: Keith Packard <keithp@keithp.com>
ao-tools/ao-usbload/ao-usbload.c
ao-tools/lib/ao-editaltos.c
ao-tools/lib/ao-editaltos.h
ao-tools/lib/ao-selfload.c
ao-tools/lib/ao-selfload.h

index 2f121cf..d0579de 100644 (file)
@@ -41,7 +41,6 @@ get_uint16(struct cc_usb *cc, uint32_t addr)
 {
        uint16_t        result;
        result = ao_self_get_uint16(cc, addr);
-       printf ("read 0x%08x = 0x%04x\n", addr, result);
        return result;
 }
 
@@ -55,7 +54,6 @@ get_uint32(struct cc_usb *cc, uint32_t addr)
        uint32_t        result;
 
        result = ao_self_get_uint32(cc, addr);
-       printf ("read 0x%08x = 0x%08x\n", addr, result);
        return result;
 }
 
@@ -88,12 +86,13 @@ static const struct option options[] = {
        { .name = "serial", .has_arg = 1, .val = 's' },
        { .name = "verbose", .has_arg = 1, .val = 'v' },
        { .name = "wait", .has_arg = 0, .val = 'w' },
+       { .name = "force", .has_arg = 0, .val = 'f' },
        { 0, 0, 0, 0},
 };
 
 static void usage(char *program)
 {
-       fprintf(stderr, "usage: %s [--raw] [--verbose=<verbose>] [--device=<device>] [-tty=<tty>] [--cal=<radio-cal>] [--serial=<serial>] [--wait] file.{elf,ihx}\n", program);
+       fprintf(stderr, "usage: %s [--raw] [--verbose=<verbose>] [--device=<device>] [-tty=<tty>] [--cal=<radio-cal>] [--serial=<serial>] [--wait] [--force] file.{elf,ihx}\n", program);
        exit(1);
 }
 
@@ -116,6 +115,48 @@ ends_with(char *whole, char *suffix)
        return strcmp(whole + whole_len - suffix_len, suffix) == 0;
 }
 
+static int
+ucs2len(uint16_t *ucs2)
+{
+       int     len = 0;
+       while (*ucs2++)
+               len++;
+       return len;
+}
+
+int
+putucs4(uint32_t c, FILE *file)
+{
+       char d;
+       int     bits;
+
+            if (c <       0x80) { d = c;                         bits= -6; }
+       else if (c <      0x800) { d= ((c >>  6) & 0x1F) | 0xC0;  bits=  0; }
+       else if (c <    0x10000) { d= ((c >> 12) & 0x0F) | 0xE0;  bits=  6; }
+       else if (c <   0x200000) { d= ((c >> 18) & 0x07) | 0xF0;  bits= 12; }
+       else if (c <  0x4000000) { d= ((c >> 24) & 0x03) | 0xF8;  bits= 18; }
+       else if (c < 0x80000000) { d= ((c >> 30) & 0x01) | 0xFC;  bits= 24; }
+       else return EOF;
+
+       if (putc (d, file) < 0)
+               return EOF;
+
+       for ( ; bits >= 0; bits-= 6)
+               if (putc (((c >> bits) & 0x3F) | 0x80, file) < 0)
+                       return EOF;
+
+       return 0;
+}
+
+static void
+putucs2str(uint16_t *ucs2str, FILE *file)
+{
+       uint16_t        ucs2;
+
+       while ((ucs2 = *ucs2str++) != 0)
+               putucs4(ucs2, file);
+}
+
 int
 main (int argc, char **argv)
 {
@@ -146,6 +187,7 @@ main (int argc, char **argv)
        int                     num_file_symbols;
        uint32_t                flash_base, flash_bound;
        int                     has_flash_size = 0;
+       int                     force = 0;
 
        while ((c = getopt_long(argc, argv, "wrT:D:c:s:v:", options, NULL)) != -1) {
                switch (c) {
@@ -174,6 +216,9 @@ main (int argc, char **argv)
                case 'v':
                        verbose++;
                        break;
+               case 'f':
+                       force = 1;
+                       break;
                default:
                        usage(argv[0]);
                        break;
@@ -323,9 +368,52 @@ main (int argc, char **argv)
                        }
                }
 
+               if (!force && was_flashed) {
+                       struct ao_usb_id        new_id, old_id;
+                       uint16_t                *new_product, *old_product;
+                       int                     new_len, old_len;
+
+                       if (!ao_heximage_usb_id(load, &new_id)) {
+                               fprintf(stderr, "Can't get new USB id\n");
+                               done(cc, 1);
+                       }
+
+                       if (!ao_self_get_usb_id(cc, &old_id)) {
+                               fprintf(stderr, "Can't get old USB id\n");
+                               done(cc, 1);
+                       }
+                       if (new_id.vid != old_id.vid || new_id.pid != old_id.pid) {
+                               fprintf(stderr, "USB ID mismatch (device is %04x/%04x image is %04x/%04x)\n",
+                                       old_id.vid, old_id.pid, new_id.vid, new_id.pid);
+                               done(cc, 1);
+                       }
+
+                       new_product = ao_heximage_usb_product(load);
+                       if (!new_product) {
+                               fprintf(stderr, "Can't get new USB product name\n");
+                               done(cc, 1);
+                       }
+                       old_product = ao_self_get_usb_product(cc);
+                       if (!old_product) {
+                               fprintf(stderr, "Can't get existing USB product name\n");
+                               done(cc, 1);
+                       }
+                       new_len = ucs2len(new_product);
+                       old_len = ucs2len(old_product);
+                       if (1 || new_len != old_len || memcmp(new_product, old_product, new_len * 2) != 0) {
+                               fprintf(stderr, "USB product mismatch (device is ");
+                               putucs2str(new_product, stderr);
+                               fprintf(stderr, ", image is ");
+                               putucs2str(old_product, stderr);
+                               fprintf(stderr, ")\n");
+                               done(cc, 1);
+                       }
+               }
+
                if (!ao_editaltos(load, serial, cal))
                        done(cc, 1);
        }
+       done(cc, 0);
 
        /* And flash the resulting image to the device
         */
index 2a52c15..0600965 100644 (file)
 #include "ao-editaltos.h"
 
 struct ao_sym ao_symbols[] = {
-       {
+       [AO_ROMCONFIG_VERSION_INDEX] = {
                .name = "ao_romconfig_version",
                .required = 1
        },
-       {
+       [AO_ROMCONFIG_CHECK_INDEX] = {
                .name = "ao_romconfig_check",
                .required = 1
        },
-       {
+       [AO_SERIAL_NUMBER_INDEX] = {
                .name = "ao_serial_number",
                .required = 1
        },
-       {
+       [AO_RADIO_CAL_INDEX] = {
                .name = "ao_radio_cal",
                .required = 0
        },
-       {
+       [AO_USB_DESCRIPTORS_INDEX] = {
                .name = "ao_usb_descriptors",
                .required = 0
        },
@@ -58,13 +58,6 @@ rewrite(struct ao_hex_image *load, unsigned address, uint8_t *data, int length)
        if (address < load->address || load->address + load->length < address + length)
                return false;
 
-       printf("rewrite %04x:", address);
-       for (i = 0; i < length; i++)
-               printf (" %02x", load->data[address - load->address + i]);
-       printf(" ->");
-       for (i = 0; i < length; i++)
-               printf (" %02x", data[i]);
-       printf("\n");
        memcpy(load->data + address - load->address, data, length);
        return true;
 }
@@ -166,3 +159,81 @@ ao_editaltos(struct ao_hex_image *image,
        }
        return true;
 }
+
+static uint16_t
+read_le16(uint8_t *src)
+{
+       return (uint16_t) src[0] | ((uint16_t) src[1] << 8);
+}
+
+bool
+ao_heximage_usb_id(struct ao_hex_image *image, struct ao_usb_id *id)
+{
+       uint32_t        usb_descriptors;
+
+       if (!AO_USB_DESCRIPTORS)
+               return false;
+       usb_descriptors = AO_USB_DESCRIPTORS - image->address;
+
+       while (image->data[usb_descriptors] != 0 && usb_descriptors < image->length) {
+               if (image->data[usb_descriptors+1] == AO_USB_DESC_DEVICE) {
+                       break;
+               }
+               usb_descriptors += image->data[usb_descriptors];
+       }
+
+       /*
+        * check to make sure there's at least 0x12 (size of a USB
+        * device descriptor) available
+        */
+       if (usb_descriptors >= image->length || image->data[usb_descriptors] != 0x12)
+               return false;
+
+       id->vid = read_le16(image->data + usb_descriptors + 8);
+       id->pid = read_le16(image->data + usb_descriptors + 10);
+
+       return true;
+}
+
+uint16_t *
+ao_heximage_usb_product(struct ao_hex_image *image)
+{
+       uint32_t        usb_descriptors;
+       int             string_num;
+       uint16_t        *product;
+       uint8_t         product_len;
+
+       if (!AO_USB_DESCRIPTORS)
+               return NULL;
+       usb_descriptors = AO_USB_DESCRIPTORS - image->address;
+
+       string_num = 0;
+       while (image->data[usb_descriptors] != 0 && usb_descriptors < image->length) {
+               if (image->data[usb_descriptors+1] == AO_USB_DESC_STRING) {
+                       ++string_num;
+                       if (string_num == 3)
+                               break;
+               }
+               usb_descriptors += image->data[usb_descriptors];
+       }
+
+       /*
+        * check to make sure there's at least 0x12 (size of a USB
+        * device descriptor) available
+        */
+       if (usb_descriptors >= image->length || image->data[usb_descriptors] == 0)
+               return NULL;
+
+       product_len = image->data[usb_descriptors] - 2;
+
+       if (usb_descriptors < product_len + 2)
+               return NULL;
+
+       product = malloc (product_len + 2);
+       if (!product)
+               return NULL;
+
+       memcpy(product, image->data + usb_descriptors + 2, product_len);
+       product[product_len/2] = 0;
+       return product;
+}
index a480954..6f2829b 100644 (file)
 extern struct ao_sym ao_symbols[];
 extern int ao_num_symbols;
 
+#define AO_USB_DESC_DEVICE             1
 #define AO_USB_DESC_STRING             3
 
-#define AO_ROMCONFIG_VERSION   (ao_symbols[0].addr)
-#define AO_ROMCONFIG_CHECK     (ao_symbols[1].addr)
-#define AO_SERIAL_NUMBER       (ao_symbols[2].addr)
-#define AO_RADIO_CAL           (ao_symbols[3].addr)
-#define AO_USB_DESCRIPTORS     (ao_symbols[4].addr)
+#define AO_ROMCONFIG_VERSION_INDEX     0
+#define AO_ROMCONFIG_CHECK_INDEX       1
+#define AO_SERIAL_NUMBER_INDEX         2
+#define AO_RADIO_CAL_INDEX             3
+#define AO_USB_DESCRIPTORS_INDEX       4
+
+#define AO_ROMCONFIG_VERSION   (ao_symbols[AO_ROMCONFIG_VERSION_INDEX].addr)
+#define AO_ROMCONFIG_CHECK     (ao_symbols[AO_ROMCONFIG_CHECK_INDEX].addr)
+#define AO_SERIAL_NUMBER       (ao_symbols[AO_SERIAL_NUMBER_INDEX].addr)
+#define AO_RADIO_CAL           (ao_symbols[AO_RADIO_CAL_INDEX].addr)
+#define AO_USB_DESCRIPTORS     (ao_symbols[AO_USB_DESCRIPTORS_INDEX].addr)
 
 struct ao_editaltos_funcs {
        uint16_t        (*get_uint16)(void *closure, uint32_t addr);
        uint32_t        (*get_uint32)(void *closure, uint32_t addr);
 };
 
+struct ao_usb_id {
+       uint16_t        vid;
+       uint16_t        pid;
+};
+
 bool
 ao_editaltos_find_symbols(struct ao_sym *file_symbols, int num_file_symbols,
                          struct ao_sym *symbols, int num_symbols);
@@ -48,4 +60,10 @@ ao_editaltos(struct ao_hex_image *image,
             uint16_t serial,
             uint32_t radio_cal);
 
+bool
+ao_heximage_usb_id(struct ao_hex_image *image, struct ao_usb_id *id);
+
+uint16_t *
+ao_heximage_usb_product(struct ao_hex_image *image);
+
 #endif /* _AO_EDITALTOS_H_ */
index b4b878d..0a23dfd 100644 (file)
@@ -157,3 +157,40 @@ ao_self_get_uint32(struct cc_usb *cc, uint32_t addr)
        free(hex);
        return v;
 }
+
+bool
+ao_self_get_usb_id(struct cc_usb *cc, struct ao_usb_id *id)
+{
+       struct ao_hex_image     *hex;
+       bool                    ret;
+
+       if (!AO_USB_DESCRIPTORS)
+               return false;
+
+       hex = ao_self_read(cc, AO_USB_DESCRIPTORS, 512);
+       if (!hex)
+               return false;
+
+       ret = ao_heximage_usb_id(hex, id);
+       free(hex);
+       return ret;
+}
+
+uint16_t *
+ao_self_get_usb_product(struct cc_usb *cc)
+{
+       struct ao_hex_image     *hex;
+       uint16_t                *ret;
+
+       if (!AO_USB_DESCRIPTORS)
+               return NULL;
+
+       hex = ao_self_read(cc, AO_USB_DESCRIPTORS, 512);
+       if (!hex)
+               return NULL;
+
+       ret = ao_heximage_usb_product(hex);
+       free(hex);
+       return ret;
+}
+
index 71529d1..4aac584 100644 (file)
@@ -22,6 +22,7 @@
 #include <stdbool.h>
 #include "ao-hex.h"
 #include "cc-usb.h"
+#include "ao-editaltos.h"
 
 struct ao_hex_image *
 ao_self_read(struct cc_usb *cc, uint32_t address, uint32_t length);
@@ -35,4 +36,13 @@ ao_self_get_uint16(struct cc_usb *cc, uint32_t addr);
 uint32_t
 ao_self_get_uint32(struct cc_usb *cc, uint32_t addr);
 
+bool
+ao_self_get_usb_id(struct cc_usb *cc, struct ao_usb_id *id);
+
+uint16_t *
+ao_self_get_usb_product(struct cc_usb *cc);
+
+uint16_t *
+ao_self_get_usb_product(struct cc_usb *cc);
+
 #endif /* _AO_SELFLOAD_H_ */