ao-stmload: Always round up load amount to 4 byte boundary
[fw/altos] / ao-tools / ao-stmload / ao-stmload.c
1 /*
2  * Copyright © 2012 Keith Packard <keithp@keithp.com>
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; version 2 of the License.
7  *
8  * This program is distributed in the hope that it will be useful, but
9  * WITHOUT ANY WARRANTY; without even the implied warranty of
10  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11  * General Public License for more details.
12  *
13  * You should have received a copy of the GNU General Public License along
14  * with this program; if not, write to the Free Software Foundation, Inc.,
15  * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
16  */
17
18 #include <err.h>
19 #include <fcntl.h>
20 #include <gelf.h>
21 #include <stdio.h>
22 #include <stdint.h>
23 #include <stdlib.h>
24 #include <sysexits.h>
25 #include <unistd.h>
26 #include <getopt.h>
27 #include <string.h>
28 #include "stlink-common.h"
29
30 #define AO_USB_DESC_STRING              3
31
32 struct sym {
33         unsigned        addr;
34         char            *name;
35         int             required;
36 } ao_symbols[] = {
37
38         { 0,    "ao_romconfig_version", 1 },
39 #define AO_ROMCONFIG_VERSION    (ao_symbols[0].addr)
40
41         { 0,    "ao_romconfig_check",   1 },
42 #define AO_ROMCONFIG_CHECK      (ao_symbols[1].addr)
43
44         { 0,    "ao_serial_number", 1 },
45 #define AO_SERIAL_NUMBER        (ao_symbols[2].addr)
46
47         { 0,    "ao_usb_descriptors", 0 },
48 #define AO_USB_DESCRIPTORS      (ao_symbols[3].addr)
49
50         { 0,    "ao_radio_cal", 0 },
51 #define AO_RADIO_CAL            (ao_symbols[4].addr)
52 };
53
54 #define NUM_SYMBOLS             5
55 #define NUM_REQUIRED_SYMBOLS    3
56
57 /*
58  * Look through the Elf file for the AltOS symbols
59  * that can be adjusted before the image is written
60  * to the device
61  */
62 static int
63 find_symbols (Elf *e)
64 {
65         Elf_Scn         *scn;
66         Elf_Data        *symbol_data = NULL;
67         GElf_Shdr       shdr;
68         GElf_Sym        sym;
69         int             i, symbol_count, s;
70         int             required = 0;
71         char            *symbol_name;
72
73         /*
74          * Find the symbols
75          */
76
77         scn = NULL;
78         while ((scn = elf_nextscn(e, scn)) != NULL) {
79                 if (gelf_getshdr(scn, &shdr) != &shdr)
80                         return 0;
81
82                 if (shdr.sh_type == SHT_SYMTAB) {
83                         symbol_data = elf_getdata(scn, NULL);
84                         symbol_count = shdr.sh_size / shdr.sh_entsize;
85                         break;
86                 }
87         }
88
89         if (!symbol_data)
90                 return 0;
91
92         for (i = 0; i < symbol_count; i++) {
93                 gelf_getsym(symbol_data, i, &sym);
94
95                 symbol_name = elf_strptr(e, shdr.sh_link, sym.st_name);
96
97                 for (s = 0; s < NUM_SYMBOLS; s++)
98                         if (!strcmp (ao_symbols[s].name, symbol_name)) {
99                                 int     t;
100                                 ao_symbols[s].addr = sym.st_value;
101                                 if (ao_symbols[s].required)
102                                         ++required;
103                         }
104         }
105
106         return required >= NUM_REQUIRED_SYMBOLS;
107 }
108
109 struct load {
110         uint32_t        addr;
111         uint32_t        len;
112         uint8_t         buf[0];
113 };
114
115 uint32_t round4(uint32_t a) {
116         return (a + 3) & ~3;
117 }
118
119 struct load *
120 new_load (uint32_t addr, uint32_t len)
121 {
122         struct load *new;
123
124         len = round4(len);
125         new = calloc (1, sizeof (struct load) + len);
126         if (!new)
127                 abort();
128
129         new->addr = addr;
130         new->len = len;
131         return new;
132 }
133
134 void
135 load_paste(struct load *into, struct load *from)
136 {
137         if (from->addr < into->addr || into->addr + into->len < from->addr + from->len)
138                 abort();
139
140         memcpy(into->buf + from->addr - into->addr, from->buf, from->len);
141 }
142
143 /*
144  * Make a new load structure large enough to hold the old one and
145  * the new data
146  */
147 struct load *
148 expand_load(struct load *from, uint32_t addr, uint32_t len)
149 {
150         struct load     *new;
151
152         if (from) {
153                 uint32_t        from_last = from->addr + from->len;
154                 uint32_t        last = addr + len;
155
156                 if (addr > from->addr)
157                         addr = from->addr;
158                 if (last < from_last)
159                         last = from_last;
160
161                 len = last - addr;
162
163                 if (addr == from->addr && len == from->len)
164                         return from;
165         }
166         new = new_load(addr, len);
167         if (from) {
168                 load_paste(new, from);
169                 free (from);
170         }
171         return new;
172 }
173
174 /*
175  * Create a new load structure with data from the existing one
176  * and the new data
177  */
178 struct load *
179 load_write(struct load *from, uint32_t addr, uint32_t len, void *data)
180 {
181         struct load     *new;
182
183         new = expand_load(from, addr, len);
184         memcpy(new->buf + addr - new->addr, data, len);
185         return new;
186 }
187
188 /*
189  * Construct a large in-memory block for all
190  * of the loaded sections of the program
191  */
192 static struct load *
193 get_load(Elf *e)
194 {
195         Elf_Scn         *scn;
196         size_t          shstrndx;
197         GElf_Shdr       shdr;
198         Elf_Data        *data;
199         uint8_t         *buf;
200         char            *got_name;
201         size_t          nphdr;
202         int             p;
203         GElf_Phdr       phdr;
204         struct load     *load = NULL;
205         
206         if (elf_getshdrstrndx(e, &shstrndx) < 0)
207                 return 0;
208
209         if (elf_getphdrnum(e, &nphdr) < 0)
210                 return 0;
211
212         /*
213          * As far as I can tell, all of the phdr sections should
214          * be flashed to memory
215          */
216         for (p = 0; p < nphdr; p++) {
217
218                 /* Find this phdr */
219                 gelf_getphdr(e, p, &phdr);
220
221                 /* Get the associated file section */
222                 scn = gelf_offscn(e, phdr.p_offset);
223
224                 if (gelf_getshdr(scn, &shdr) != &shdr)
225                         abort();
226
227                 data = elf_getdata(scn, NULL);
228
229                 /* Write the section data into the memory block */
230                 load = load_write(load, phdr.p_paddr, phdr.p_filesz, data->d_buf);
231         }
232         return load;
233 }
234
235 /*
236  * Edit the to-be-written memory block
237  */
238 static int
239 rewrite(struct load *load, unsigned addr, uint8_t *data, int len)
240 {
241         int             i;
242
243         if (addr < load->addr || load->addr + load->len < addr + len)
244                 return 0;
245
246         printf("rewrite %04x:", addr);
247         for (i = 0; i < len; i++)
248                 printf (" %02x", load->buf[addr - load->addr + i]);
249         printf(" ->");
250         for (i = 0; i < len; i++)
251                 printf (" %02x", data[i]);
252         printf("\n");
253         memcpy(load->buf + addr - load->addr, data, len);
254 }
255
256 /*
257  * Open the specified ELF file and
258  * check for the symbols we need
259  */
260
261 Elf *
262 ao_open_elf(char *name)
263 {
264         int             fd;
265         Elf             *e;
266         Elf_Scn         *scn;
267         Elf_Data        *symbol_data = NULL;
268         GElf_Shdr       shdr;
269         GElf_Sym        sym;
270         size_t          n, shstrndx, sz;
271         int             i, symbol_count, s;
272         int             required = 0;
273
274         if (elf_version(EV_CURRENT) == EV_NONE)
275                 return NULL;
276
277         fd = open(name, O_RDONLY, 0);
278
279         if (fd < 0)
280                 return NULL;
281
282         e = elf_begin(fd, ELF_C_READ, NULL);
283
284         if (!e)
285                 return NULL;
286
287         if (elf_kind(e) != ELF_K_ELF)
288                 return NULL;
289
290         if (elf_getshdrstrndx(e, &shstrndx) != 0)
291                 return NULL;
292
293         if (!find_symbols(e)) {
294                 fprintf (stderr, "Cannot find required symbols\n");
295                 return NULL;
296         }
297
298         return e;
299 }
300
301 /*
302  * Read a 32-bit value from the target device with arbitrary
303  * alignment
304  */
305 static uint32_t
306 get_uint32(stlink_t *sl, uint32_t addr)
307 {
308         const           uint8_t *data = sl->q_buf;
309         uint32_t        actual_addr;
310         int             off;
311         uint32_t        result;
312
313         sl->q_len = 0;
314
315         printf ("read 0x%x\n", addr);
316
317         actual_addr = addr & ~3;
318         
319         stlink_read_mem32(sl, actual_addr, 8);
320
321         if (sl->q_len != 8)
322                 abort();
323
324         off = addr & 3;
325         result = data[off] | (data[off + 1] << 8) | (data[off+2] << 16) | (data[off+3] << 24);
326         printf ("read 0x%08x = 0x%08x\n", addr, result);
327         return result;
328 }
329
330 /*
331  * Read a 16-bit value from the target device with arbitrary
332  * alignment
333  */
334 static uint16_t
335 get_uint16(stlink_t *sl, uint32_t addr)
336 {
337         const           uint8_t *data = sl->q_buf;
338         uint32_t        actual_addr;
339         int             off;
340         uint16_t        result;
341
342         sl->q_len = 0;
343
344
345         actual_addr = addr & ~3;
346         
347         stlink_read_mem32(sl, actual_addr, 8);
348
349         if (sl->q_len != 8)
350                 abort();
351
352         off = addr & 3;
353         result = data[off] | (data[off + 1] << 8);
354         printf ("read 0x%08x = 0x%04x\n", addr, result);
355         return result;
356 }
357
358 /*
359  * Check to see if the target device has been
360  * flashed with a similar firmware image before
361  *
362  * This is done by looking for the same romconfig version,
363  * which should be at the same location as the linker script
364  * places this at 0x100 from the start of the rom section
365  */
366 static int
367 check_flashed(stlink_t *sl)
368 {
369         uint16_t        romconfig_version = get_uint16(sl, AO_ROMCONFIG_VERSION);
370         uint16_t        romconfig_check = get_uint16(sl, AO_ROMCONFIG_CHECK);
371
372         if (romconfig_version != (uint16_t) ~romconfig_check) {
373                 fprintf (stderr, "Device has not been flashed before\n");
374                 return 0;
375         }
376         return 1;
377 }
378
379 static const struct option options[] = {
380         { .name = "device", .has_arg = 1, .val = 'D' },
381         { .name = "cal", .has_arg = 1, .val = 'c' },
382         { .name = "serial", .has_arg = 1, .val = 's' },
383         { 0, 0, 0, 0},
384 };
385
386 static void usage(char *program)
387 {
388         fprintf(stderr, "usage: %s [--cal=<radio-cal>] [--serial=<serial>] file.elf\n", program);
389         exit(1);
390 }
391
392 void
393 done(stlink_t *sl, int code)
394 {
395         if (sl) {
396                 stlink_reset(sl);
397                 stlink_run(sl);
398                 stlink_exit_debug_mode(sl);
399                 stlink_close(sl);
400         }
401         exit (code);
402 }
403
404 int
405 main (int argc, char **argv)
406 {
407         char                    *device = NULL;
408         char                    *filename;
409         Elf                     *e;
410         char                    *serial_end;
411         unsigned int            serial = 0;
412         char                    *serial_ucs2;
413         int                     serial_ucs2_len;
414         char                    serial_int[2];
415         unsigned int            s;
416         int                     i;
417         int                     string_num;
418         uint32_t                cal = 0;
419         char                    cal_int[4];
420         char                    *cal_end;
421         int                     c;
422         stlink_t                *sl;
423         int                     was_flashed = 0;
424         struct load             *load;
425
426         while ((c = getopt_long(argc, argv, "D:c:s:", options, NULL)) != -1) {
427                 switch (c) {
428                 case 'D':
429                         device = optarg;
430                         break;
431                 case 'c':
432                         cal = strtoul(optarg, &cal_end, 10);
433                         if (cal_end == optarg || *cal_end != '\0')
434                                 usage(argv[0]);
435                         break;
436                 case 's':
437                         serial = strtoul(optarg, &serial_end, 10);
438                         if (serial_end == optarg || *serial_end != '\0')
439                                 usage(argv[0]);
440                         break;
441                 default:
442                         usage(argv[0]);
443                         break;
444                 }
445         }
446
447         filename = argv[optind];
448         if (filename == NULL)
449                 usage(argv[0]);
450
451         /*
452          * Open the source file and load the symbols and
453          * flash data
454          */
455         
456         e = ao_open_elf(filename);
457         if (!e) {
458                 fprintf(stderr, "Cannot open file \"%s\"\n", filename);
459                 exit(1);
460         }
461
462         if (!find_symbols(e)) {
463                 fprintf(stderr, "Cannot find symbols in \"%s\"\n", filename);
464                 exit(1);
465         }
466
467         if (!(load = get_load(e))) {
468                 fprintf(stderr, "Cannot find program data in \"%s\"\n", filename);
469                 exit(1);
470         }
471                 
472         /* Connect to the programming dongle
473          */
474         
475         if (device) {
476                 sl = stlink_v1_open(50);
477         } else {
478                 sl = stlink_open_usb(50);
479                 
480         }
481         if (!sl) {
482                 fprintf (stderr, "No STLink devices present\n");
483                 done (sl, 1);
484         }
485
486         sl->verbose = 50;
487
488         /* Verify that the loaded image fits entirely within device flash
489          */
490         if (load->addr < sl->flash_base ||
491             sl->flash_base + sl->flash_size < load->addr + load->len) {
492                 fprintf (stderr, "\%s\": Invalid memory range 0x%08x - 0x%08x\n", filename,
493                          load->addr, load->addr + load->len);
494                 done(sl, 1);
495         }
496
497         /* Enter debugging mode
498          */
499         if (stlink_current_mode(sl) == STLINK_DEV_DFU_MODE)
500                 stlink_exit_dfu_mode(sl);
501
502         if (stlink_current_mode(sl) != STLINK_DEV_DEBUG_MODE)
503                 stlink_enter_swd_mode(sl);
504
505         /* Go fetch existing config values
506          * if available
507          */
508         was_flashed = check_flashed(sl);
509
510         if (!serial) {
511                 if (!was_flashed) {
512                         fprintf (stderr, "Must provide serial number\n");
513                         done(sl, 1);
514                 }
515                 serial = get_uint16(sl, AO_SERIAL_NUMBER);
516                 if (!serial || serial == 0xffff) {
517                         fprintf (stderr, "Invalid existing serial %d\n", serial);
518                         done(sl, 1);
519                 }
520         }
521
522         if (!cal && AO_RADIO_CAL && was_flashed) {
523                 cal = get_uint32(sl, AO_RADIO_CAL);
524                 if (!cal || cal == 0xffffffff) {
525                         fprintf (stderr, "Invalid existing rf cal %d\n", cal);
526                         done(sl, 1);
527                 }
528         }
529
530         /* Write the config values into the flash image
531          */
532
533         serial_int[0] = serial & 0xff;
534         serial_int[1] = (serial >> 8) & 0xff;
535
536         if (!rewrite(load, AO_SERIAL_NUMBER, serial_int, sizeof (serial_int))) {
537                 fprintf(stderr, "Cannot rewrite serial integer at %08x\n",
538                         AO_SERIAL_NUMBER);
539                 done(sl, 1);
540         }
541
542         if (AO_USB_DESCRIPTORS) {
543                 unsigned        usb_descriptors;
544                 usb_descriptors = AO_USB_DESCRIPTORS - load->addr;
545                 string_num = 0;
546
547                 while (load->buf[usb_descriptors] != 0 && usb_descriptors < load->len) {
548                         if (load->buf[usb_descriptors+1] == AO_USB_DESC_STRING) {
549                                 ++string_num;
550                                 if (string_num == 4)
551                                         break;
552                         }
553                         usb_descriptors += load->buf[usb_descriptors];
554                 }
555                 if (usb_descriptors >= load->len || load->buf[usb_descriptors] == 0 ) {
556                         fprintf(stderr, "Cannot rewrite serial string at %08x\n", AO_USB_DESCRIPTORS);
557                         done(sl, 1);
558                 }
559
560                 serial_ucs2_len = load->buf[usb_descriptors] - 2;
561                 serial_ucs2 = malloc(serial_ucs2_len);
562                 if (!serial_ucs2) {
563                         fprintf(stderr, "Malloc(%d) failed\n", serial_ucs2_len);
564                         done(sl, 1);
565                 }
566                 s = serial;
567                 for (i = serial_ucs2_len / 2; i; i--) {
568                         serial_ucs2[i * 2 - 1] = 0;
569                         serial_ucs2[i * 2 - 2] = (s % 10) + '0';
570                         s /= 10;
571                 }
572                 if (!rewrite(load, usb_descriptors + 2 + load->addr, serial_ucs2, serial_ucs2_len)) {
573                         fprintf (stderr, "Cannot rewrite USB descriptor at %08x\n", AO_USB_DESCRIPTORS);
574                         done(sl, 1);
575                 }
576         }
577
578         if (cal && AO_RADIO_CAL) {
579                 cal_int[0] = cal & 0xff;
580                 cal_int[1] = (cal >> 8) & 0xff;
581                 cal_int[2] = (cal >> 16) & 0xff;
582                 cal_int[3] = (cal >> 24) & 0xff;
583
584                 if (!rewrite(load, AO_RADIO_CAL, cal_int, sizeof (cal_int))) {
585                         fprintf(stderr, "Cannot rewrite radio calibration at %08x\n", AO_RADIO_CAL);
586                         exit(1);
587                 }
588         }
589
590         /* And flash the resulting image to the device
591          */
592         if (stlink_write_flash(sl, load->addr, load->buf, load->len) < 0) {
593                 fprintf (stderr, "\"%s\": Write failed\n", filename);
594                 done(sl, 1);
595         }
596
597         done(sl, 0);
598 }